From 26c28d5d2bef70e701dd13805cd756596dc2ac52 Mon Sep 17 00:00:00 2001 From: mivanit Date: Sun, 16 Jun 2024 22:15:19 -0700 Subject: [PATCH 001/158] try running CI for python >=3.8 --- .github/workflows/checks.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 995dda85..c1f13d3c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -30,6 +30,10 @@ jobs: strategy: matrix: versions: + - python: "3.8" + torch: "1.13.1" + - python: "3.9" + torch: "1.13.1" - python: "3.10" torch: "1.13.1" - python: "3.10" From cc768c726e4aafd2b485fb1ad7acb67b2b2b78a4 Mon Sep 17 00:00:00 2001 From: mivanit Date: Sun, 16 Jun 2024 22:17:14 -0700 Subject: [PATCH 002/158] update pyproject.toml for python >=3.8 --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index dd445a01..40597bf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,11 @@ license = "GPL-3.0-only" authors = ["mivanit "] readme = "README.md" classifiers=[ + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Development Status :: 4 - Beta", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", @@ -15,7 +18,7 @@ classifiers=[ repository = "https://github.com/mivanit/muutils" [tool.poetry.dependencies] -python = "^3.10" +python = "^3.8" numpy = { version = "^1.22.4", optional = true } torch = { version = ">=1.13.1", optional = true } jaxtyping = { version = "^0.2.12", optional = true } From 305a081fa8da98e174d3caee27d435f1d6eec93a Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 16:44:16 -0700 Subject: [PATCH 003/158] fix some deps --- poetry.lock | 556 +++++++++++++++++++++++++++++++------------------ pyproject.toml | 3 +- 2 files changed, 361 insertions(+), 198 deletions(-) diff --git a/poetry.lock b/poetry.lock index 3e8eddd8..69b82465 100644 --- a/poetry.lock +++ b/poetry.lock @@ -23,7 +23,7 @@ wrapt = [ name = "asttokens" version = "2.4.1" description = "Annotate AST trees with source code positions" -optional = false +optional = true python-versions = "*" files = [ {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"}, @@ -110,66 +110,136 @@ files = [ [[package]] name = "contourpy" -version = "1.2.1" +version = "1.1.0" +description = "Python library for calculating contours of 2D quadrilateral grids" +optional = false +python-versions = ">=3.8" +files = [ + {file = "contourpy-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:89f06eff3ce2f4b3eb24c1055a26981bffe4e7264acd86f15b97e40530b794bc"}, + {file = "contourpy-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dffcc2ddec1782dd2f2ce1ef16f070861af4fb78c69862ce0aab801495dda6a3"}, + {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25ae46595e22f93592d39a7eac3d638cda552c3e1160255258b695f7b58e5655"}, + {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:17cfaf5ec9862bc93af1ec1f302457371c34e688fbd381f4035a06cd47324f48"}, + {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18a64814ae7bce73925131381603fff0116e2df25230dfc80d6d690aa6e20b37"}, + {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90c81f22b4f572f8a2110b0b741bb64e5a6427e0a198b2cdc1fbaf85f352a3aa"}, + {file = "contourpy-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53cc3a40635abedbec7f1bde60f8c189c49e84ac180c665f2cd7c162cc454baa"}, + {file = "contourpy-1.1.0-cp310-cp310-win32.whl", hash = "sha256:9b2dd2ca3ac561aceef4c7c13ba654aaa404cf885b187427760d7f7d4c57cff8"}, + {file = "contourpy-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:1f795597073b09d631782e7245016a4323cf1cf0b4e06eef7ea6627e06a37ff2"}, + {file = "contourpy-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0b7b04ed0961647691cfe5d82115dd072af7ce8846d31a5fac6c142dcce8b882"}, + {file = "contourpy-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27bc79200c742f9746d7dd51a734ee326a292d77e7d94c8af6e08d1e6c15d545"}, + {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:052cc634bf903c604ef1a00a5aa093c54f81a2612faedaa43295809ffdde885e"}, + {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9382a1c0bc46230fb881c36229bfa23d8c303b889b788b939365578d762b5c18"}, + {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5cec36c5090e75a9ac9dbd0ff4a8cf7cecd60f1b6dc23a374c7d980a1cd710e"}, + {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f0cbd657e9bde94cd0e33aa7df94fb73c1ab7799378d3b3f902eb8eb2e04a3a"}, + {file = "contourpy-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:181cbace49874f4358e2929aaf7ba84006acb76694102e88dd15af861996c16e"}, + {file = "contourpy-1.1.0-cp311-cp311-win32.whl", hash = "sha256:edb989d31065b1acef3828a3688f88b2abb799a7db891c9e282df5ec7e46221b"}, + {file = "contourpy-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fb3b7d9e6243bfa1efb93ccfe64ec610d85cfe5aec2c25f97fbbd2e58b531256"}, + {file = "contourpy-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcb41692aa09aeb19c7c213411854402f29f6613845ad2453d30bf421fe68fed"}, + {file = "contourpy-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5d123a5bc63cd34c27ff9c7ac1cd978909e9c71da12e05be0231c608048bb2ae"}, + {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62013a2cf68abc80dadfd2307299bfa8f5aa0dcaec5b2954caeb5fa094171103"}, + {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0b6616375d7de55797d7a66ee7d087efe27f03d336c27cf1f32c02b8c1a5ac70"}, + {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:317267d915490d1e84577924bd61ba71bf8681a30e0d6c545f577363157e5e94"}, + {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d551f3a442655f3dcc1285723f9acd646ca5858834efeab4598d706206b09c9f"}, + {file = "contourpy-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e7a117ce7df5a938fe035cad481b0189049e8d92433b4b33aa7fc609344aafa1"}, + {file = "contourpy-1.1.0-cp38-cp38-win32.whl", hash = "sha256:108dfb5b3e731046a96c60bdc46a1a0ebee0760418951abecbe0fc07b5b93b27"}, + {file = "contourpy-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4f26b25b4f86087e7d75e63212756c38546e70f2a92d2be44f80114826e1cd4"}, + {file = "contourpy-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc00bb4225d57bff7ebb634646c0ee2a1298402ec10a5fe7af79df9a51c1bfd9"}, + {file = "contourpy-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:189ceb1525eb0655ab8487a9a9c41f42a73ba52d6789754788d1883fb06b2d8a"}, + {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f2931ed4741f98f74b410b16e5213f71dcccee67518970c42f64153ea9313b9"}, + {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:30f511c05fab7f12e0b1b7730ebdc2ec8deedcfb505bc27eb570ff47c51a8f15"}, + {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:143dde50520a9f90e4a2703f367cf8ec96a73042b72e68fcd184e1279962eb6f"}, + {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e94bef2580e25b5fdb183bf98a2faa2adc5b638736b2c0a4da98691da641316a"}, + {file = "contourpy-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ed614aea8462735e7d70141374bd7650afd1c3f3cb0c2dbbcbe44e14331bf002"}, + {file = "contourpy-1.1.0-cp39-cp39-win32.whl", hash = "sha256:71551f9520f008b2950bef5f16b0e3587506ef4f23c734b71ffb7b89f8721999"}, + {file = "contourpy-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:438ba416d02f82b692e371858143970ed2eb6337d9cdbbede0d8ad9f3d7dd17d"}, + {file = "contourpy-1.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a698c6a7a432789e587168573a864a7ea374c6be8d4f31f9d87c001d5a843493"}, + {file = "contourpy-1.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397b0ac8a12880412da3551a8cb5a187d3298a72802b45a3bd1805e204ad8439"}, + {file = "contourpy-1.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:a67259c2b493b00e5a4d0f7bfae51fb4b3371395e47d079a4446e9b0f4d70e76"}, + {file = "contourpy-1.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2b836d22bd2c7bb2700348e4521b25e077255ebb6ab68e351ab5aa91ca27e027"}, + {file = "contourpy-1.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084eaa568400cfaf7179b847ac871582199b1b44d5699198e9602ecbbb5f6104"}, + {file = "contourpy-1.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:911ff4fd53e26b019f898f32db0d4956c9d227d51338fb3b03ec72ff0084ee5f"}, + {file = "contourpy-1.1.0.tar.gz", hash = "sha256:e53046c3863828d21d531cc3b53786e6580eb1ba02477e8681009b6aa0870b21"}, +] + +[package.dependencies] +numpy = ">=1.16" + +[package.extras] +bokeh = ["bokeh", "selenium"] +docs = ["furo", "sphinx-copybutton"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.2.0)", "types-Pillow"] +test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] +test-no-images = ["pytest", "pytest-cov", "wurlitzer"] + +[[package]] +name = "contourpy" +version = "1.1.1" description = "Python library for calculating contours of 2D quadrilateral grids" optional = false -python-versions = ">=3.9" -files = [ - {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"}, - {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"}, - {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"}, - {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"}, - {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"}, - {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"}, - {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"}, - {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"}, - {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"}, - {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"}, - {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"}, - {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"}, - {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"}, - {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"}, - {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"}, - {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"}, - {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"}, - {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"}, - {file = "contourpy-1.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b"}, - {file = "contourpy-1.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445"}, - {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02"}, - {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083"}, - {file = "contourpy-1.2.1-cp39-cp39-win32.whl", hash = "sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba"}, - {file = "contourpy-1.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"}, - {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"}, +python-versions = ">=3.8" +files = [ + {file = "contourpy-1.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:46e24f5412c948d81736509377e255f6040e94216bf1a9b5ea1eaa9d29f6ec1b"}, + {file = "contourpy-1.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e48694d6a9c5a26ee85b10130c77a011a4fedf50a7279fa0bdaf44bafb4299d"}, + {file = "contourpy-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a66045af6cf00e19d02191ab578a50cb93b2028c3eefed999793698e9ea768ae"}, + {file = "contourpy-1.1.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ebf42695f75ee1a952f98ce9775c873e4971732a87334b099dde90b6af6a916"}, + {file = "contourpy-1.1.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6aec19457617ef468ff091669cca01fa7ea557b12b59a7908b9474bb9674cf0"}, + {file = "contourpy-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:462c59914dc6d81e0b11f37e560b8a7c2dbab6aca4f38be31519d442d6cde1a1"}, + {file = "contourpy-1.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6d0a8efc258659edc5299f9ef32d8d81de8b53b45d67bf4bfa3067f31366764d"}, + {file = "contourpy-1.1.1-cp310-cp310-win32.whl", hash = "sha256:d6ab42f223e58b7dac1bb0af32194a7b9311065583cc75ff59dcf301afd8a431"}, + {file = "contourpy-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:549174b0713d49871c6dee90a4b499d3f12f5e5f69641cd23c50a4542e2ca1eb"}, + {file = "contourpy-1.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:407d864db716a067cc696d61fa1ef6637fedf03606e8417fe2aeed20a061e6b2"}, + {file = "contourpy-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe80c017973e6a4c367e037cb31601044dd55e6bfacd57370674867d15a899b"}, + {file = "contourpy-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e30aaf2b8a2bac57eb7e1650df1b3a4130e8d0c66fc2f861039d507a11760e1b"}, + {file = "contourpy-1.1.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3de23ca4f381c3770dee6d10ead6fff524d540c0f662e763ad1530bde5112532"}, + {file = "contourpy-1.1.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:566f0e41df06dfef2431defcfaa155f0acfa1ca4acbf8fd80895b1e7e2ada40e"}, + {file = "contourpy-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b04c2f0adaf255bf756cf08ebef1be132d3c7a06fe6f9877d55640c5e60c72c5"}, + {file = "contourpy-1.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d0c188ae66b772d9d61d43c6030500344c13e3f73a00d1dc241da896f379bb62"}, + {file = "contourpy-1.1.1-cp311-cp311-win32.whl", hash = "sha256:0683e1ae20dc038075d92e0e0148f09ffcefab120e57f6b4c9c0f477ec171f33"}, + {file = "contourpy-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:8636cd2fc5da0fb102a2504fa2c4bea3cbc149533b345d72cdf0e7a924decc45"}, + {file = "contourpy-1.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:560f1d68a33e89c62da5da4077ba98137a5e4d3a271b29f2f195d0fba2adcb6a"}, + {file = "contourpy-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:24216552104ae8f3b34120ef84825400b16eb6133af2e27a190fdc13529f023e"}, + {file = "contourpy-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56de98a2fb23025882a18b60c7f0ea2d2d70bbbcfcf878f9067234b1c4818442"}, + {file = "contourpy-1.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:07d6f11dfaf80a84c97f1a5ba50d129d9303c5b4206f776e94037332e298dda8"}, + {file = "contourpy-1.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1eaac5257a8f8a047248d60e8f9315c6cff58f7803971170d952555ef6344a7"}, + {file = "contourpy-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19557fa407e70f20bfaba7d55b4d97b14f9480856c4fb65812e8a05fe1c6f9bf"}, + {file = "contourpy-1.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:081f3c0880712e40effc5f4c3b08feca6d064cb8cfbb372ca548105b86fd6c3d"}, + {file = "contourpy-1.1.1-cp312-cp312-win32.whl", hash = "sha256:059c3d2a94b930f4dafe8105bcdc1b21de99b30b51b5bce74c753686de858cb6"}, + {file = "contourpy-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:f44d78b61740e4e8c71db1cf1fd56d9050a4747681c59ec1094750a658ceb970"}, + {file = "contourpy-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:70e5a10f8093d228bb2b552beeb318b8928b8a94763ef03b858ef3612b29395d"}, + {file = "contourpy-1.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8394e652925a18ef0091115e3cc191fef350ab6dc3cc417f06da66bf98071ae9"}, + {file = "contourpy-1.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5bd5680f844c3ff0008523a71949a3ff5e4953eb7701b28760805bc9bcff217"}, + {file = "contourpy-1.1.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66544f853bfa85c0d07a68f6c648b2ec81dafd30f272565c37ab47a33b220684"}, + {file = "contourpy-1.1.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0c02b75acfea5cab07585d25069207e478d12309557f90a61b5a3b4f77f46ce"}, + {file = "contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41339b24471c58dc1499e56783fedc1afa4bb018bcd035cfb0ee2ad2a7501ef8"}, + {file = "contourpy-1.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f29fb0b3f1217dfe9362ec55440d0743fe868497359f2cf93293f4b2701b8251"}, + {file = "contourpy-1.1.1-cp38-cp38-win32.whl", hash = "sha256:f9dc7f933975367251c1b34da882c4f0e0b2e24bb35dc906d2f598a40b72bfc7"}, + {file = "contourpy-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:498e53573e8b94b1caeb9e62d7c2d053c263ebb6aa259c81050766beb50ff8d9"}, + {file = "contourpy-1.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ba42e3810999a0ddd0439e6e5dbf6d034055cdc72b7c5c839f37a7c274cb4eba"}, + {file = "contourpy-1.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c06e4c6e234fcc65435223c7b2a90f286b7f1b2733058bdf1345d218cc59e34"}, + {file = "contourpy-1.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6fab080484e419528e98624fb5c4282148b847e3602dc8dbe0cb0669469887"}, + {file = "contourpy-1.1.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93df44ab351119d14cd1e6b52a5063d3336f0754b72736cc63db59307dabb718"}, + {file = "contourpy-1.1.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eafbef886566dc1047d7b3d4b14db0d5b7deb99638d8e1be4e23a7c7ac59ff0f"}, + {file = "contourpy-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efe0fab26d598e1ec07d72cf03eaeeba8e42b4ecf6b9ccb5a356fde60ff08b85"}, + {file = "contourpy-1.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f08e469821a5e4751c97fcd34bcb586bc243c39c2e39321822060ba902eac49e"}, + {file = "contourpy-1.1.1-cp39-cp39-win32.whl", hash = "sha256:bfc8a5e9238232a45ebc5cb3bfee71f1167064c8d382cadd6076f0d51cff1da0"}, + {file = "contourpy-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:c84fdf3da00c2827d634de4fcf17e3e067490c4aea82833625c4c8e6cdea0887"}, + {file = "contourpy-1.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:229a25f68046c5cf8067d6d6351c8b99e40da11b04d8416bf8d2b1d75922521e"}, + {file = "contourpy-1.1.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a10dab5ea1bd4401c9483450b5b0ba5416be799bbd50fc7a6cc5e2a15e03e8a3"}, + {file = "contourpy-1.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4f9147051cb8fdb29a51dc2482d792b3b23e50f8f57e3720ca2e3d438b7adf23"}, + {file = "contourpy-1.1.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a75cc163a5f4531a256f2c523bd80db509a49fc23721b36dd1ef2f60ff41c3cb"}, + {file = "contourpy-1.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b53d5769aa1f2d4ea407c65f2d1d08002952fac1d9e9d307aa2e1023554a163"}, + {file = "contourpy-1.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11b836b7dbfb74e049c302bbf74b4b8f6cb9d0b6ca1bf86cfa8ba144aedadd9c"}, + {file = "contourpy-1.1.1.tar.gz", hash = "sha256:96ba37c2e24b7212a77da85004c38e7c4d155d3e72a45eeaf22c1f03f607e8ab"}, ] [package.dependencies] -numpy = ">=1.20" +numpy = {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""} [package.extras] bokeh = ["bokeh", "selenium"] docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] -mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pillow"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.4.1)", "types-Pillow"] test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] -test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] +test-no-images = ["pytest", "pytest-cov", "wurlitzer"] [[package]] name = "coverage" @@ -271,7 +341,7 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"] name = "decorator" version = "5.1.1" description = "Decorators for Humans" -optional = false +optional = true python-versions = ">=3.5" files = [ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, @@ -311,7 +381,7 @@ test = ["pytest (>=6)"] name = "executing" version = "2.0.1" description = "Get the currently executing AST node of a frame, and other information" -optional = false +optional = true python-versions = ">=3.5" files = [ {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, @@ -323,18 +393,18 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth [[package]] name = "filelock" -version = "3.14.0" +version = "3.15.1" description = "A platform independent file lock." optional = true python-versions = ">=3.8" files = [ - {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, - {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, + {file = "filelock-3.15.1-py3-none-any.whl", hash = "sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac"}, + {file = "filelock-3.15.1.tar.gz", hash = "sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -441,6 +511,43 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe, test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "importlib-metadata" +version = "7.1.0" +description = "Read metadata from Python packages" +optional = true +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] + +[[package]] +name = "importlib-resources" +version = "6.4.0" +description = "Read resources from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, + {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -470,7 +577,7 @@ files = [ name = "ipython" version = "8.25.0" description = "IPython: Productive Interactive Computing" -optional = false +optional = true python-versions = ">=3.10" files = [ {file = "ipython-8.25.0-py3-none-any.whl", hash = "sha256:53eee7ad44df903a06655871cbab66d156a051fd86f3ec6750470ac9604ac1ab"}, @@ -520,23 +627,25 @@ colors = ["colorama (>=0.4.6)"] [[package]] name = "jaxtyping" -version = "0.2.29" +version = "0.2.19" description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees." optional = true -python-versions = "~=3.9" +python-versions = "~=3.8" files = [ - {file = "jaxtyping-0.2.29-py3-none-any.whl", hash = "sha256:3580fc4dfef4c98ef2372c2c81314d89b98a186eb78d69d925fd0546025d556f"}, - {file = "jaxtyping-0.2.29.tar.gz", hash = "sha256:e1cd916ed0196e40402b0638449e7d051571562b2cd68d8b94961a383faeb409"}, + {file = "jaxtyping-0.2.19-py3-none-any.whl", hash = "sha256:651352032799d422987e783fd1b77699b53c3bb28ffa644bbca5f75ec4fbb843"}, + {file = "jaxtyping-0.2.19.tar.gz", hash = "sha256:21ff4c3caec6781cadfe980b019dde856c1011e17d11dfe8589298040056325a"}, ] [package.dependencies] -typeguard = "2.13.3" +numpy = ">=1.20.0" +typeguard = ">=2.13.3" +typing-extensions = ">=3.7.4.1" [[package]] name = "jedi" version = "0.19.1" description = "An autocompletion tool for Python that can be used for text editors." -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, @@ -729,43 +838,51 @@ files = [ [[package]] name = "libcst" -version = "1.4.0" -description = "A concrete syntax tree with AST-like properties for Python 3.0 through 3.12 programs." +version = "1.1.0" +description = "A concrete syntax tree with AST-like properties for Python 3.5, 3.6, 3.7, 3.8, 3.9, and 3.10 programs." optional = false -python-versions = ">=3.9" -files = [ - {file = "libcst-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:279b54568ea1f25add50ea4ba3d76d4f5835500c82f24d54daae4c5095b986aa"}, - {file = "libcst-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3401dae41fe24565387a65baee3887e31a44e3e58066b0250bc3f3ccf85b1b5a"}, - {file = "libcst-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1989fa12d3cd79118ebd29ebe2a6976d23d509b1a4226bc3d66fcb7cb50bd5d"}, - {file = "libcst-1.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:addc6d585141a7677591868886f6bda0577529401a59d210aa8112114340e129"}, - {file = "libcst-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17d71001cb25e94cfe8c3d997095741a8c4aa7a6d234c0f972bc42818c88dfaf"}, - {file = "libcst-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:2d47de16d105e7dd5f4e01a428d9f4dc1e71efd74f79766daf54528ce37f23c3"}, - {file = "libcst-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e6227562fc5c9c1efd15dfe90b0971ae254461b8b6b23c1b617139b6003de1c1"}, - {file = "libcst-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3399e6c95df89921511b44d8c5bf6a75bcbc2d51f1f6429763609ba005c10f6b"}, - {file = "libcst-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48601e3e590e2d6a7ab8c019cf3937c70511a78d778ab3333764531253acdb33"}, - {file = "libcst-1.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42797309bb725f0f000510d5463175ccd7155395f09b5e7723971b0007a976d"}, - {file = "libcst-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb4e42ea107a37bff7f9fdbee9532d39f9ea77b89caa5c5112b37057b12e0838"}, - {file = "libcst-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:9d0cc3c5a2a51fa7e1d579a828c0a2e46b2170024fd8b1a0691c8a52f3abb2d9"}, - {file = "libcst-1.4.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7ece51d935bc9bf60b528473d2e5cc67cbb88e2f8146297e40ee2c7d80be6f13"}, - {file = "libcst-1.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:81653dea1cdfa4c6520a7c5ffb95fa4d220cbd242e446c7a06d42d8636bfcbba"}, - {file = "libcst-1.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6abce0e66bba2babfadc20530fd3688f672d565674336595b4623cd800b91ef"}, - {file = "libcst-1.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5da9d7dc83801aba3b8d911f82dc1a375db0d508318bad79d9fb245374afe068"}, - {file = "libcst-1.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c54aa66c86d8ece9c93156a2cf5ca512b0dce40142fe9e072c86af2bf892411"}, - {file = "libcst-1.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:62e2682ee1567b6a89c91853865372bf34f178bfd237853d84df2b87b446e654"}, - {file = "libcst-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b8ecdba8934632b4dadacb666cd3816627a6ead831b806336972ccc4ba7ca0e9"}, - {file = "libcst-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8e54c777b8d27339b70f304d16fc8bc8674ef1bd34ed05ea874bf4921eb5a313"}, - {file = "libcst-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:061d6855ef30efe38b8a292b7e5d57c8e820e71fc9ec9846678b60a934b53bbb"}, - {file = "libcst-1.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb0abf627ee14903d05d0ad9b2c6865f1b21eb4081e2c7bea1033f85db2b8bae"}, - {file = "libcst-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d024f44059a853b4b852cfc04fec33e346659d851371e46fc8e7c19de24d3da9"}, - {file = "libcst-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:3c6a8faab9da48c5b371557d0999b4ca51f4f2cbd37ee8c2c4df0ac01c781465"}, - {file = "libcst-1.4.0.tar.gz", hash = "sha256:449e0b16604f054fa7f27c3ffe86ea7ef6c409836fe68fe4e752a1894175db00"}, +python-versions = ">=3.8" +files = [ + {file = "libcst-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:63f75656fd733dc20354c46253fde3cf155613e37643c3eaf6f8818e95b7a3d1"}, + {file = "libcst-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ae11eb1ea55a16dc0cdc61b41b29ac347da70fec14cc4381248e141ee2fbe6c"}, + {file = "libcst-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4bc745d0c06420fe2644c28d6ddccea9474fb68a2135904043676deb4fa1e6bc"}, + {file = "libcst-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c1f2da45f1c45634090fd8672c15e0159fdc46853336686959b2d093b6e10fa"}, + {file = "libcst-1.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:003e5e83a12eed23542c4ea20fdc8de830887cc03662432bb36f84f8c4841b81"}, + {file = "libcst-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:3ebbb9732ae3cc4ae7a0e97890bed0a57c11d6df28790c2b9c869f7da653c7c7"}, + {file = "libcst-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d68c34e3038d3d1d6324eb47744cbf13f2c65e1214cf49db6ff2a6603c1cd838"}, + {file = "libcst-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9dffa1795c2804d183efb01c0f1efd20a7831db6a21a0311edf90b4100d67436"}, + {file = "libcst-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc9b6ac36d7ec9db2f053014ea488086ca2ed9c322be104fbe2c71ca759da4bb"}, + {file = "libcst-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b7a38ec4c1c009ac39027d51558b52851fb9234669ba5ba62283185963a31c"}, + {file = "libcst-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5297a16e575be8173185e936b7765c89a3ca69d4ae217a4af161814a0f9745a7"}, + {file = "libcst-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:7ccaf53925f81118aeaadb068a911fac8abaff608817d7343da280616a5ca9c1"}, + {file = "libcst-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:75816647736f7e09c6120bdbf408456f99b248d6272277eed9a58cf50fb8bc7d"}, + {file = "libcst-1.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c8f26250f87ca849a7303ed7a4fd6b2c7ac4dec16b7d7e68ca6a476d7c9bfcdb"}, + {file = "libcst-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d37326bd6f379c64190a28947a586b949de3a76be00176b0732c8ee87d67ebe"}, + {file = "libcst-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d8cf974cfa2487b28f23f56c4bff90d550ef16505e58b0dca0493d5293784b"}, + {file = "libcst-1.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82d1271403509b0a4ee6ff7917c2d33b5a015f44d1e208abb1da06ba93b2a378"}, + {file = "libcst-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:bca1841693941fdd18371824bb19a9702d5784cd347cb8231317dbdc7062c5bc"}, + {file = "libcst-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f36f592e035ef84f312a12b75989dde6a5f6767fe99146cdae6a9ee9aff40dd0"}, + {file = "libcst-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f561c9a84eca18be92f4ad90aa9bd873111efbea995449301719a1a7805dbc5c"}, + {file = "libcst-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97fbc73c87e9040e148881041fd5ffa2a6ebf11f64b4ccb5b52e574b95df1a15"}, + {file = "libcst-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99fdc1929703fd9e7408aed2e03f58701c5280b05c8911753a8d8619f7dfdda5"}, + {file = "libcst-1.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0bf69cbbab5016d938aac4d3ae70ba9ccb3f90363c588b3b97be434e6ba95403"}, + {file = "libcst-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:fe41b33aa73635b1651f64633f429f7aa21f86d2db5748659a99d9b7b1ed2a90"}, + {file = "libcst-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:73c086705ed34dbad16c62c9adca4249a556c1b022993d511da70ea85feaf669"}, + {file = "libcst-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3a07ecfabbbb8b93209f952a365549e65e658831e9231649f4f4e4263cad24b1"}, + {file = "libcst-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c653d9121d6572d8b7f8abf20f88b0a41aab77ff5a6a36e5a0ec0f19af0072e8"}, + {file = "libcst-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f1cd308a4c2f71d5e4eec6ee693819933a03b78edb2e4cc5e3ad1afd5fb3f07"}, + {file = "libcst-1.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8afb6101b8b3c86c5f9cec6b90ab4da16c3c236fe7396f88e8b93542bb341f7c"}, + {file = "libcst-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:d22d1abfe49aa60fc61fa867e10875a9b3024ba5a801112f4d7ba42d8d53242e"}, + {file = "libcst-1.1.0.tar.gz", hash = "sha256:0acbacb9a170455701845b7e940e2d7b9519db35a86768d86330a0b0deae1086"}, ] [package.dependencies] pyyaml = ">=5.2" +typing-extensions = ">=3.7.4.2" +typing-inspect = ">=0.4.0" [package.extras] -dev = ["Sphinx (>=5.1.1)", "black (==23.12.1)", "build (>=0.10.0)", "coverage (>=4.5.4)", "fixit (==2.1.0)", "flake8 (==7.0.0)", "hypothesis (>=4.36.0)", "hypothesmith (>=0.0.4)", "jinja2 (==3.1.4)", "jupyter (>=1.0.0)", "maturin (>=0.8.3,<1.6)", "nbsphinx (>=0.4.2)", "prompt-toolkit (>=2.0.9)", "pyre-check (==0.9.18)", "setuptools-rust (>=1.5.2)", "setuptools-scm (>=6.0.1)", "slotscheck (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "ufmt (==2.6.0)", "usort (==1.0.8.post1)"] +dev = ["Sphinx (>=5.1.1)", "black (==23.9.1)", "build (>=0.10.0)", "coverage (>=4.5.4)", "fixit (==2.0.0.post1)", "flake8 (>=3.7.8,<5)", "hypothesis (>=4.36.0)", "hypothesmith (>=0.0.4)", "jinja2 (==3.1.2)", "jupyter (>=1.0.0)", "maturin (>=0.8.3,<0.16)", "nbsphinx (>=0.4.2)", "prompt-toolkit (>=2.0.9)", "pyre-check (==0.9.18)", "setuptools-rust (>=1.5.2)", "setuptools-scm (>=6.0.1)", "slotscheck (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "ufmt (==2.2.0)", "usort (==1.0.7)"] [[package]] name = "markdown-it-py" @@ -862,61 +979,77 @@ files = [ [[package]] name = "matplotlib" -version = "3.9.0" +version = "3.7.5" description = "Python plotting package" optional = false -python-versions = ">=3.9" -files = [ - {file = "matplotlib-3.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56"}, - {file = "matplotlib-3.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b"}, - {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241"}, - {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d"}, - {file = "matplotlib-3.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4"}, - {file = "matplotlib-3.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463"}, - {file = "matplotlib-3.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38"}, - {file = "matplotlib-3.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152"}, - {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85"}, - {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb"}, - {file = "matplotlib-3.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674"}, - {file = "matplotlib-3.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be"}, - {file = "matplotlib-3.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382"}, - {file = "matplotlib-3.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84"}, - {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5"}, - {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db"}, - {file = "matplotlib-3.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7"}, - {file = "matplotlib-3.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf"}, - {file = "matplotlib-3.9.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956"}, - {file = "matplotlib-3.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a"}, - {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321"}, - {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89"}, - {file = "matplotlib-3.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b"}, - {file = "matplotlib-3.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888"}, - {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0"}, - {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03"}, - {file = "matplotlib-3.9.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd"}, - {file = "matplotlib-3.9.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e"}, - {file = "matplotlib-3.9.0.tar.gz", hash = "sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a"}, +python-versions = ">=3.8" +files = [ + {file = "matplotlib-3.7.5-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:4a87b69cb1cb20943010f63feb0b2901c17a3b435f75349fd9865713bfa63925"}, + {file = "matplotlib-3.7.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d3ce45010fefb028359accebb852ca0c21bd77ec0f281952831d235228f15810"}, + {file = "matplotlib-3.7.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fbea1e762b28400393d71be1a02144aa16692a3c4c676ba0178ce83fc2928fdd"}, + {file = "matplotlib-3.7.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec0e1adc0ad70ba8227e957551e25a9d2995e319c29f94a97575bb90fa1d4469"}, + {file = "matplotlib-3.7.5-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6738c89a635ced486c8a20e20111d33f6398a9cbebce1ced59c211e12cd61455"}, + {file = "matplotlib-3.7.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1210b7919b4ed94b5573870f316bca26de3e3b07ffdb563e79327dc0e6bba515"}, + {file = "matplotlib-3.7.5-cp310-cp310-win32.whl", hash = "sha256:068ebcc59c072781d9dcdb82f0d3f1458271c2de7ca9c78f5bd672141091e9e1"}, + {file = "matplotlib-3.7.5-cp310-cp310-win_amd64.whl", hash = "sha256:f098ffbaab9df1e3ef04e5a5586a1e6b1791380698e84938d8640961c79b1fc0"}, + {file = "matplotlib-3.7.5-cp311-cp311-macosx_10_12_universal2.whl", hash = "sha256:f65342c147572673f02a4abec2d5a23ad9c3898167df9b47c149f32ce61ca078"}, + {file = "matplotlib-3.7.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4ddf7fc0e0dc553891a117aa083039088d8a07686d4c93fb8a810adca68810af"}, + {file = "matplotlib-3.7.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0ccb830fc29442360d91be48527809f23a5dcaee8da5f4d9b2d5b867c1b087b8"}, + {file = "matplotlib-3.7.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efc6bb28178e844d1f408dd4d6341ee8a2e906fc9e0fa3dae497da4e0cab775d"}, + {file = "matplotlib-3.7.5-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b15c4c2d374f249f324f46e883340d494c01768dd5287f8bc00b65b625ab56c"}, + {file = "matplotlib-3.7.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d028555421912307845e59e3de328260b26d055c5dac9b182cc9783854e98fb"}, + {file = "matplotlib-3.7.5-cp311-cp311-win32.whl", hash = "sha256:fe184b4625b4052fa88ef350b815559dd90cc6cc8e97b62f966e1ca84074aafa"}, + {file = "matplotlib-3.7.5-cp311-cp311-win_amd64.whl", hash = "sha256:084f1f0f2f1010868c6f1f50b4e1c6f2fb201c58475494f1e5b66fed66093647"}, + {file = "matplotlib-3.7.5-cp312-cp312-macosx_10_12_universal2.whl", hash = "sha256:34bceb9d8ddb142055ff27cd7135f539f2f01be2ce0bafbace4117abe58f8fe4"}, + {file = "matplotlib-3.7.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c5a2134162273eb8cdfd320ae907bf84d171de948e62180fa372a3ca7cf0f433"}, + {file = "matplotlib-3.7.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:039ad54683a814002ff37bf7981aa1faa40b91f4ff84149beb53d1eb64617980"}, + {file = "matplotlib-3.7.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d742ccd1b09e863b4ca58291728db645b51dab343eebb08d5d4b31b308296ce"}, + {file = "matplotlib-3.7.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:743b1c488ca6a2bc7f56079d282e44d236bf375968bfd1b7ba701fd4d0fa32d6"}, + {file = "matplotlib-3.7.5-cp312-cp312-win_amd64.whl", hash = "sha256:fbf730fca3e1f23713bc1fae0a57db386e39dc81ea57dc305c67f628c1d7a342"}, + {file = "matplotlib-3.7.5-cp38-cp38-macosx_10_12_universal2.whl", hash = "sha256:cfff9b838531698ee40e40ea1a8a9dc2c01edb400b27d38de6ba44c1f9a8e3d2"}, + {file = "matplotlib-3.7.5-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:1dbcca4508bca7847fe2d64a05b237a3dcaec1f959aedb756d5b1c67b770c5ee"}, + {file = "matplotlib-3.7.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4cdf4ef46c2a1609a50411b66940b31778db1e4b73d4ecc2eaa40bd588979b13"}, + {file = "matplotlib-3.7.5-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:167200ccfefd1674b60e957186dfd9baf58b324562ad1a28e5d0a6b3bea77905"}, + {file = "matplotlib-3.7.5-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:53e64522934df6e1818b25fd48cf3b645b11740d78e6ef765fbb5fa5ce080d02"}, + {file = "matplotlib-3.7.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3e3bc79b2d7d615067bd010caff9243ead1fc95cf735c16e4b2583173f717eb"}, + {file = "matplotlib-3.7.5-cp38-cp38-win32.whl", hash = "sha256:6b641b48c6819726ed47c55835cdd330e53747d4efff574109fd79b2d8a13748"}, + {file = "matplotlib-3.7.5-cp38-cp38-win_amd64.whl", hash = "sha256:f0b60993ed3488b4532ec6b697059897891927cbfc2b8d458a891b60ec03d9d7"}, + {file = "matplotlib-3.7.5-cp39-cp39-macosx_10_12_universal2.whl", hash = "sha256:090964d0afaff9c90e4d8de7836757e72ecfb252fb02884016d809239f715651"}, + {file = "matplotlib-3.7.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:9fc6fcfbc55cd719bc0bfa60bde248eb68cf43876d4c22864603bdd23962ba25"}, + {file = "matplotlib-3.7.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7cc3078b019bb863752b8b60e8b269423000f1603cb2299608231996bd9d54"}, + {file = "matplotlib-3.7.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e4e9a868e8163abaaa8259842d85f949a919e1ead17644fb77a60427c90473c"}, + {file = "matplotlib-3.7.5-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fa7ebc995a7d747dacf0a717d0eb3aa0f0c6a0e9ea88b0194d3a3cd241a1500f"}, + {file = "matplotlib-3.7.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3785bfd83b05fc0e0c2ae4c4a90034fe693ef96c679634756c50fe6efcc09856"}, + {file = "matplotlib-3.7.5-cp39-cp39-win32.whl", hash = "sha256:29b058738c104d0ca8806395f1c9089dfe4d4f0f78ea765c6c704469f3fffc81"}, + {file = "matplotlib-3.7.5-cp39-cp39-win_amd64.whl", hash = "sha256:fd4028d570fa4b31b7b165d4a685942ae9cdc669f33741e388c01857d9723eab"}, + {file = "matplotlib-3.7.5-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2a9a3f4d6a7f88a62a6a18c7e6a84aedcaf4faf0708b4ca46d87b19f1b526f88"}, + {file = "matplotlib-3.7.5-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b9b3fd853d4a7f008a938df909b96db0b454225f935d3917520305b90680579c"}, + {file = "matplotlib-3.7.5-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0ad550da9f160737d7890217c5eeed4337d07e83ca1b2ca6535078f354e7675"}, + {file = "matplotlib-3.7.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:20da7924a08306a861b3f2d1da0d1aa9a6678e480cf8eacffe18b565af2813e7"}, + {file = "matplotlib-3.7.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b45c9798ea6bb920cb77eb7306409756a7fab9db9b463e462618e0559aecb30e"}, + {file = "matplotlib-3.7.5-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a99866267da1e561c7776fe12bf4442174b79aac1a47bd7e627c7e4d077ebd83"}, + {file = "matplotlib-3.7.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b6aa62adb6c268fc87d80f963aca39c64615c31830b02697743c95590ce3fbb"}, + {file = "matplotlib-3.7.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:e530ab6a0afd082d2e9c17eb1eb064a63c5b09bb607b2b74fa41adbe3e162286"}, + {file = "matplotlib-3.7.5.tar.gz", hash = "sha256:1e5c971558ebc811aa07f54c7b7c677d78aa518ef4c390e14673a09e0860184a"}, ] [package.dependencies] contourpy = ">=1.0.1" cycler = ">=0.10" fonttools = ">=4.22.0" -kiwisolver = ">=1.3.1" -numpy = ">=1.23" +importlib-resources = {version = ">=3.2.0", markers = "python_version < \"3.10\""} +kiwisolver = ">=1.0.1" +numpy = ">=1.20,<2" packaging = ">=20.0" -pillow = ">=8" +pillow = ">=6.2.0" pyparsing = ">=2.3.1" python-dateutil = ">=2.7" -[package.extras] -dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setuptools (>=64)", "setuptools_scm (>=7)"] - [[package]] name = "matplotlib-inline" version = "0.1.7" description = "Inline Matplotlib backend for Jupyter" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca"}, @@ -1043,65 +1176,57 @@ files = [ [[package]] name = "networkx" -version = "3.3" +version = "3.1" description = "Python package for creating and manipulating graphs and networks" optional = true -python-versions = ">=3.10" +python-versions = ">=3.8" files = [ - {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"}, - {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"}, + {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, + {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, ] [package.extras] -default = ["matplotlib (>=3.6)", "numpy (>=1.23)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] -developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] -doc = ["myst-nb (>=1.0)", "numpydoc (>=1.7)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] -extra = ["lxml (>=4.6)", "pydot (>=2.0)", "pygraphviz (>=1.12)", "sympy (>=1.10)"] -test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] +default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] +developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] +doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] +test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "numpy" -version = "1.26.4" +version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false -python-versions = ">=3.9" -files = [ - {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, - {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, - {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, - {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, - {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, - {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, - {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, - {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, - {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, - {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, - {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, - {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, - {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, - {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, - {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, - {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, - {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, - {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, - {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, - {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, - {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, - {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, - {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, - {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, - {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, - {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, - {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, - {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, - {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, - {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, - {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, - {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, - {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +python-versions = ">=3.8" +files = [ + {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, + {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, + {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, + {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, + {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, + {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, + {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, + {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, + {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, + {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] [[package]] @@ -1261,7 +1386,7 @@ files = [ name = "parso" version = "0.8.4" description = "A Python Parser" -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"}, @@ -1287,7 +1412,7 @@ files = [ name = "pexpect" version = "4.9.0" description = "Pexpect allows easy control of interactive console applications." -optional = false +optional = true python-versions = "*" files = [ {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, @@ -1418,7 +1543,7 @@ testing = ["pytest", "pytest-benchmark"] name = "prompt-toolkit" version = "3.0.47" description = "Library for building powerful interactive command lines in Python" -optional = false +optional = true python-versions = ">=3.7.0" files = [ {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"}, @@ -1432,7 +1557,7 @@ wcwidth = "*" name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" -optional = false +optional = true python-versions = "*" files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, @@ -1443,7 +1568,7 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" -optional = false +optional = true python-versions = "*" files = [ {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, @@ -1508,6 +1633,7 @@ mccabe = ">=0.6,<0.8" platformdirs = ">=2.2.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} tomlkit = ">=0.10.1" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} [package.extras] spelling = ["pyenchant (>=3.2,<4.0)"] @@ -1654,6 +1780,7 @@ files = [ [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -1684,7 +1811,7 @@ files = [ name = "stack-data" version = "0.6.3" description = "Extract data from python stack frames and tracebacks for informative displays" -optional = false +optional = true python-versions = "*" files = [ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, @@ -1806,7 +1933,7 @@ optree = ["optree (>=0.9.1)"] name = "traitlets" version = "5.14.3" description = "Traitlets Python configuration system" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, @@ -1842,18 +1969,22 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"] [[package]] name = "typeguard" -version = "2.13.3" +version = "4.3.0" description = "Run-time type checker for Python" optional = true -python-versions = ">=3.5.3" +python-versions = ">=3.8" files = [ - {file = "typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1"}, - {file = "typeguard-2.13.3.tar.gz", hash = "sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4"}, + {file = "typeguard-4.3.0-py3-none-any.whl", hash = "sha256:4d24c5b39a117f8a895b9da7a9b3114f04eb63bade45a4492de49b175b6f7dfa"}, + {file = "typeguard-4.3.0.tar.gz", hash = "sha256:92ee6a0aec9135181eae6067ebd617fd9de8d75d714fb548728a4933b1dea651"}, ] +[package.dependencies] +importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""} +typing-extensions = ">=4.10.0" + [package.extras] -doc = ["sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["mypy", "pytest", "typing-extensions"] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.3.0)"] +test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"] [[package]] name = "typer" @@ -1883,11 +2014,26 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +description = "Runtime inspection utilities for typing module." +optional = false +python-versions = "*" +files = [ + {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"}, + {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"}, +] + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "wcwidth" version = "0.2.13" description = "Measures the displayed width of unicode strings in a terminal" -optional = false +optional = true python-versions = "*" files = [ {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, @@ -1973,10 +2119,26 @@ files = [ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, ] +[[package]] +name = "zipp" +version = "3.19.2" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, + {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, +] + +[package.extras] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] + [extras] array = ["jaxtyping", "numpy", "torch"] +notebook = ["ipython"] [metadata] lock-version = "2.0" -python-versions = "^3.10" -content-hash = "b103697ab0f6424f1007107b22e004c99ebe010859e44340308feaa642bc9824" +python-versions = "^3.8" +content-hash = "dcc7222c78e1d3797c2b80513b86c27c44394206e1f98af1e5943c4b3f506cab" diff --git a/pyproject.toml b/pyproject.toml index 40597bf0..681a3a46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,11 @@ python = "^3.8" numpy = { version = "^1.22.4", optional = true } torch = { version = ">=1.13.1", optional = true } jaxtyping = { version = "^0.2.12", optional = true } +ipython = { version = "^8.20.0", optional = true, python = "^3.10" } [tool.poetry.extras] array = ["numpy", "torch", "jaxtyping"] +notebook = ["ipython"] [tool.poetry.group.dev.dependencies] pytest = "^7.2.1" @@ -36,7 +38,6 @@ mypy = "^1.0.1" pytest-cov = "^4.1.0" coverage-badge = "^1.1.0" matplotlib = "^3.0.0" -ipython = "^8.20.0" [build-system] requires = ["poetry-core"] From 18c578621ce24da85772ebfb91a1d2356d1e1e7c Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 17:08:22 -0700 Subject: [PATCH 004/158] add py 3.12 to tests --- .github/workflows/checks.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index c1f13d3c..77d9f635 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -37,9 +37,11 @@ jobs: - python: "3.10" torch: "1.13.1" - python: "3.10" - torch: "2.0.1" + torch: "2.3.1" - python: "3.11" - torch: "2.0.1" + torch: "2.3.1" + - python: "3.12" + torch: "2.3.1" steps: - name: Checkout code uses: actions/checkout@v2 From efe6b662c4044770c53cade55138b0cb31426983 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 17:08:30 -0700 Subject: [PATCH 005/158] some typing fixes --- muutils/json_serialize/array.py | 1 + muutils/json_serialize/json_serialize.py | 9 ++++---- .../json_serialize/serializable_dataclass.py | 16 +++++++------ muutils/json_serialize/util.py | 23 ++++++++++++------- muutils/mlutils.py | 7 +++--- muutils/nbutils/configure_notebook.py | 1 + 6 files changed, 35 insertions(+), 22 deletions(-) diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index aef5a97e..8130d160 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -1,3 +1,4 @@ +from __future__ import annotations import typing import warnings from typing import Any, Iterable, Literal, Optional, Sequence diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index a857765f..b6097e0d 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -1,9 +1,10 @@ +from __future__ import annotations import inspect import types import warnings from dataclasses import dataclass, is_dataclass from pathlib import Path -from typing import Any, Callable, Iterable, Mapping +from typing import Any, Callable, Iterable, Mapping, Dict, Set, Union try: from muutils.json_serialize.array import ArrayMode, serialize_array @@ -39,7 +40,7 @@ "__annotations__", ) -SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = { +SERIALIZER_SPECIAL_FUNCS: Dict[str, Callable] = { "str": str, "dir": dir, "type": try_catch(lambda x: str(type(x).__name__)), @@ -48,12 +49,12 @@ "sourcefile": try_catch(lambda x: inspect.getsourcefile(x)), } -SERIALIZE_DIRECT_AS_STR: set[str] = { +SERIALIZE_DIRECT_AS_STR: Set[str] = { "", "", } -ObjectPath = MonoTuple[str | int] +ObjectPath = MonoTuple[Union[str,int]] @dataclass diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 0772f08f..8ffca827 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -1,10 +1,11 @@ +from __future__ import annotations import abc import dataclasses import json import types import typing import warnings -from typing import Any, Callable, Optional, Type, TypeVar +from typing import Any, Callable, Optional, Type, TypeVar, Union # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access @@ -34,17 +35,18 @@ class SerializableField(dataclasses.Field): def __init__( self, - default: Any | dataclasses._MISSING_TYPE = dataclasses.MISSING, - default_factory: ( - Callable[[], Any] | dataclasses._MISSING_TYPE - ) = dataclasses.MISSING, + default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, + default_factory: Union[ + Callable[[], Any], + dataclasses._MISSING_TYPE + ] = dataclasses.MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, # TODO: add field for custom comparator (such as serializing) - metadata: types.MappingProxyType | None = None, - kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING, + metadata: Optional[types.MappingProxyType] = None, + kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, loading_fn: Optional[Callable[[Any], Any]] = None, diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 81c7b06a..eedcde61 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -1,9 +1,11 @@ +from __future__ import annotations import functools import inspect +import sys import types import typing import warnings -from typing import Any, Callable, Iterable, Literal, Union +from typing import Any, Callable, Iterable, Literal, Union, Dict _NUMPY_WORKING: bool try: @@ -16,11 +18,12 @@ TypeErrorMode = Union[ErrorMode, Literal["try_convert"]] -JSONitem = Union[bool, int, float, str, list, dict[str, Any], None] -JSONdict = dict[str, JSONitem] +JSONitem = Union[bool, int, float, str, list, Dict[str, Any], None] +JSONdict = Dict[str, JSONitem] Hashableitem = Union[bool, int, float, str, tuple] -if typing.TYPE_CHECKING: +# or if python version <3.9 +if typing.TYPE_CHECKING or sys.version_info[1] < 9: MonoTuple = typing.Sequence else: @@ -38,17 +41,21 @@ def __init_subclass__(cls, *args, **kwargs): # idk why mypy thinks there is no such function in typing @typing._tp_cache # type: ignore def __class_getitem__(cls, params): - if isinstance(params, (type, types.UnionType)): - return types.GenericAlias(tuple, (params, Ellipsis)) + if isinstance(params, type): + typing.GenericAlias(tuple, (params, Ellipsis)) + elif any("typing.UnionType" in str(t) for t in params.mro()): + # TODO: unsure about this + # check via mro + return typing.GenericAlias(tuple, (params, Ellipsis)) # test if has len and is iterable elif isinstance(params, Iterable): if len(params) == 0: return tuple elif len(params) == 1: - return types.GenericAlias(tuple, (params[0], Ellipsis)) + return typing.GenericAlias(tuple, (params[0], Ellipsis)) else: raise TypeError( - f"MonoTuple expects 1 type argument, got {len(params) = } \n\t{params = }" + f"MonoTuple expects 1 type argument, got {params = }" ) diff --git a/muutils/mlutils.py b/muutils/mlutils.py index 6f21a85c..f094ce7d 100644 --- a/muutils/mlutils.py +++ b/muutils/mlutils.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json import os import random @@ -5,7 +6,7 @@ import warnings from itertools import islice from pathlib import Path -from typing import Any, Callable, TypeVar +from typing import Any, Callable, TypeVar, Union, Optional ARRAY_IMPORTS: bool try: @@ -23,7 +24,7 @@ GLOBAL_SEED: int = DEFAULT_SEED -def get_device(device: "str|torch.device|None" = None) -> "torch.device": +def get_device(device: "Union[str,torch.device,None]" = None) -> "torch.device": """Get the torch.device instance on which `torch.Tensor`s should be allocated.""" if not ARRAY_IMPORTS: raise ImportError( @@ -130,7 +131,7 @@ def get_checkpoint_paths_for_run( def register_method( method_dict: dict[str, Callable[..., Any]], - custom_name: str | None = None, + custom_name: Optional[str] = None, ) -> Callable[[F], F]: """Decorator to add a method to the method_dict""" diff --git a/muutils/nbutils/configure_notebook.py b/muutils/nbutils/configure_notebook.py index ab77e37d..7cc80b00 100644 --- a/muutils/nbutils/configure_notebook.py +++ b/muutils/nbutils/configure_notebook.py @@ -1,3 +1,4 @@ +from __future__ import annotations import os import typing import warnings From 1a762a8df978b1f52713ad2a3c4aa7dd3d294124 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 17:18:51 -0700 Subject: [PATCH 006/158] wip typing fixes --- muutils/dictmagic.py | 103 +++++++++--------- .../json_serialize/serializable_dataclass.py | 10 ++ .../test_sdc_defaults.py | 3 +- 3 files changed, 64 insertions(+), 52 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 1301b719..ac051d0b 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -1,13 +1,14 @@ +from __future__ import annotations import typing import warnings from collections import defaultdict -from typing import Any, Callable, Generic, Hashable, Iterable, Literal, TypeVar +from typing import Any, Callable, Generic, Hashable, Iterable, Literal, TypeVar, Dict _KT = TypeVar("_KT") _VT = TypeVar("_VT") -class DefaulterDict(dict[_KT, _VT], Generic[_KT, _VT]): +class DefaulterDict(Dict[_KT, _VT], Generic[_KT, _VT]): """like a defaultdict, but default_factory is passed the key as an argument""" def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs): @@ -41,7 +42,7 @@ def defaultdict_to_dict_recursive(dd: defaultdict | DefaulterDict) -> dict: } -def dotlist_to_nested_dict(dot_dict: dict[str, Any], sep: str = ".") -> dict[str, Any]: +def dotlist_to_nested_dict(dot_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: """Convert a dict with dot-separated keys to a nested dict Example: @@ -62,11 +63,11 @@ def dotlist_to_nested_dict(dot_dict: dict[str, Any], sep: str = ".") -> dict[str def nested_dict_to_dotlist( - nested_dict: dict[str, Any], + nested_dict: Dict[str, Any], sep: str = ".", allow_lists: bool = False, -) -> dict[str, Any]: - def _recurse(current: Any, parent_key: str = "") -> dict[str, Any]: +) -> Dict[str, Any]: + def _recurse(current: Any, parent_key: str = "") -> Dict[str, Any]: items: dict = dict() new_key: str @@ -95,9 +96,9 @@ def _recurse(current: Any, parent_key: str = "") -> dict[str, Any]: def update_with_nested_dict( - original: dict[str, Any], - update: dict[str, Any], -) -> dict[str, Any]: + original: Dict[str, Any], + update: Dict[str, Any], +) -> Dict[str, Any]: """Update a dict with a nested dict Example: @@ -105,9 +106,9 @@ def update_with_nested_dict( {'a': {'b': 2}, 'c': -1} # Arguments - - `original: dict[str, Any]` + - `original: Dict[str, Any]` the dict to update (will be modified in-place) - - `update: dict[str, Any]` + - `update: Dict[str, Any]` the dict to update with # Returns @@ -127,12 +128,12 @@ def update_with_nested_dict( def kwargs_to_nested_dict( - kwargs_dict: dict[str, Any], + kwargs_dict: Dict[str, Any], sep: str = ".", strip_prefix: str | None = None, when_unknown_prefix: typing.Literal["raise", "warn", "ignore"] = "warn", transform_key: Callable[[str], str] | None = None, -) -> dict[str, Any]: +) -> Dict[str, Any]: """given kwargs from fire, convert them to a nested dict if strip_prefix is not None, then all keys must start with the prefix. by default, @@ -152,7 +153,7 @@ def main(**kwargs): ``` # Arguments - - `kwargs_dict: dict[str, Any]` + - `kwargs_dict: Dict[str, Any]` the kwargs dict to convert - `sep: str = "."` the separator to use for nested keys @@ -163,7 +164,7 @@ def main(**kwargs): - `transform_key: Callable[[str], str] | None = None` a function to apply to each key before adding it to the dict (applied after stripping the prefix) """ - filtered_kwargs: dict[str, Any] = dict() + filtered_kwargs: Dict[str, Any] = dict() for key, value in kwargs_dict.items(): if strip_prefix is not None: if not key.startswith(strip_prefix): @@ -197,8 +198,8 @@ def is_numeric_consecutive(lst: list[str]) -> bool: def condense_nested_dicts_numeric_keys( - data: dict[str, Any], -) -> dict[str, Any]: + data: Dict[str, Any], +) -> Dict[str, Any]: """condense a nested dict, by condensing numeric keys with matching values to ranges # Examples: @@ -224,7 +225,7 @@ def condense_nested_dicts_numeric_keys( return data # output dict - condensed_data: dict[str, Any] = {} + condensed_data: Dict[str, Any] = {} # Identify ranges of identical values and condense i: int = 0 @@ -244,15 +245,15 @@ def condense_nested_dicts_numeric_keys( def condense_nested_dicts_matching_values( - data: dict[str, Any], + data: Dict[str, Any], val_condense_fallback_mapping: Callable[[Any], Hashable] | None = None, -) -> dict[str, Any]: +) -> Dict[str, Any]: """condense a nested dict, by condensing keys with matching values # Examples: # Parameters: - - `data : dict[str, Any]` + - `data : Dict[str, Any]` data to process - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` a function to apply to each value before adding it to the dict (if it's not hashable) @@ -272,7 +273,7 @@ def condense_nested_dicts_matching_values( # Find all identical values and condense by stitching together keys values_grouped: defaultdict[Any, list[str]] = defaultdict(list) - data_persist: dict[str, Any] = dict() + data_persist: Dict[str, Any] = dict() for key, value in data.items(): if not isinstance(value, dict): try: @@ -298,11 +299,11 @@ def condense_nested_dicts_matching_values( def condense_nested_dicts( - data: dict[str, Any], + data: Dict[str, Any], condense_numeric_keys: bool = True, condense_matching_values: bool = True, val_condense_fallback_mapping: Callable[[Any], Hashable] | None = None, -) -> dict[str, Any]: +) -> Dict[str, Any]: """condense a nested dict, by condensing numeric or matching keys with matching values to ranges combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()` @@ -311,7 +312,7 @@ def condense_nested_dicts( it's not reversible because types are lost to make the printing pretty # Parameters: - - `data : dict[str, Any]` + - `data : Dict[str, Any]` data to process - `condense_numeric_keys : bool` whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") @@ -336,7 +337,7 @@ def condense_nested_dicts( def tuple_dims_replace( - t: tuple[int, ...], dims_names_map: dict[int, str] | None = None + t: tuple[int, ...], dims_names_map: Dict[int, str] | None = None ) -> tuple[int | str, ...]: if dims_names_map is None: return t @@ -344,7 +345,7 @@ def tuple_dims_replace( return tuple(dims_names_map.get(x, x) for x in t) -TensorDict = dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] +TensorDict = Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] TensorIterable = Iterable[tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] TensorDictFormats = Literal["dict", "json", "yaml", "yml"] @@ -360,19 +361,19 @@ def condense_tensor_dict( shapes_convert: Callable[[tuple], Any] = _default_shapes_convert, drop_batch_dims: int = 0, sep: str = ".", - dims_names_map: dict[int, str] | None = None, + dims_names_map: Dict[int, str] | None = None, condense_numeric_keys: bool = True, condense_matching_values: bool = True, val_condense_fallback_mapping: Callable[[Any], Hashable] | None = None, return_format: TensorDictFormats | None = None, -) -> str | dict[str, str | tuple[int, ...]]: +) -> str | Dict[str, str | tuple[int, ...]]: """Convert a dictionary of tensors to a dictionary of shapes. by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`. # Parameters: - - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]` + - `data : Dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]` a either a `TensorDict` dict from strings to tensors, or an `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` ) - `fmt : TensorDictFormats` format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. @@ -386,7 +387,7 @@ def condense_tensor_dict( - `sep : str` separator to use for nested keys (defaults to `'.'`) - - `dims_names_map : dict[int, str] | None` + - `dims_names_map : Dict[int, str] | None` convert certain dimension values in shape. not perfect, can be buggy (defaults to `None`) - `condense_numeric_keys : bool` @@ -402,7 +403,7 @@ def condense_tensor_dict( legacy alias for `fmt` kwarg # Returns: - - `str|dict[str, str|tuple[int, ...]]` + - `str|Dict[str, str|tuple[int, ...]]` dict if `return_format='dict'`, a string for `json` or `yaml` output # Examples: @@ -458,7 +459,7 @@ def condense_tensor_dict( ) # get shapes - data_shapes: dict[str, str | tuple[int, ...]] = { + data_shapes: Dict[str, str | tuple[int, ...]] = { k: shapes_convert( tuple_dims_replace( tuple(v.shape)[drop_batch_dims:], @@ -469,10 +470,10 @@ def condense_tensor_dict( } # nest the dict - data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) + data_nested: Dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) # condense the nested dict - data_condensed: dict[str, str | tuple[int, ...]] = condense_nested_dicts( + data_condensed: Dict[str, str | tuple[int, ...]] = condense_nested_dicts( data=data_nested, condense_numeric_keys=condense_numeric_keys, condense_matching_values=condense_matching_values, @@ -480,19 +481,19 @@ def condense_tensor_dict( ) # return in the specified format - match fmt.lower(): - case "dict": - return data_condensed - case "json": - import json - - return json.dumps(data_condensed, indent=2) - case "yaml" | "yml": - try: - import yaml # type: ignore[import-untyped] - - return yaml.dump(data_condensed, sort_keys=False) - except ImportError as e: - raise ValueError("PyYAML is required for YAML output") from e - case _: - raise ValueError(f"Invalid return format: {fmt}") + fmt_lower: str = fmt.lower() + if fmt_lower == "dict": + return data_condensed + elif fmt_lower == "json": + import json + + return json.dumps(data_condensed, indent=2) + elif fmt_lower in ["yaml", "yml"]: + try: + import yaml # type: ignore[import-untyped] + + return yaml.dump(data_condensed, sort_keys=False) + except ImportError as e: + raise ValueError("PyYAML is required for YAML output") from e + else: + raise ValueError(f"Invalid return format: {fmt}") diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 8ffca827..3c49e003 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -2,6 +2,7 @@ import abc import dataclasses import json +import sys import types import typing import warnings @@ -72,6 +73,15 @@ def __init__( else: super_kwargs["metadata"] = types.MappingProxyType({}) + # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy + if (sys.version_info[1] < 9): + if (super_kwargs["kw_only"] == True): # noqa: E712 + raise ValueError( + "kw_only is not supported in python >=3.9" + ) + else: + del super_kwargs["kw_only"] + # actually init the super class super().__init__(**super_kwargs) # type: ignore[call-arg] diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 1272fc46..ef373bf3 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -1,3 +1,4 @@ +from typing import Dict from muutils.json_serialize import ( JsonSerializer, SerializableDataclass, @@ -42,7 +43,7 @@ def test_sdc_strip_format_jser(): assert recovered == instance -TYPE_MAP: dict[str, type] = {x.__name__: x for x in [int, float, str, bool]} +TYPE_MAP: Dict[str, type] = {x.__name__: x for x in [int, float, str, bool]} @serializable_dataclass From df4a6c95ec2a1af0b82501751be4eb1c7fc7e1e4 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 19:38:17 -0700 Subject: [PATCH 007/158] fix a bunch of type hints --- muutils/dictmagic.py | 38 +++++++++---------- muutils/group_equiv.py | 1 + .../json_serialize/serializable_dataclass.py | 10 +++++ muutils/logger/headerfuncs.py | 1 + muutils/logger/logger.py | 1 + muutils/logger/loggingstream.py | 1 + muutils/logger/simplelogger.py | 5 ++- muutils/logger/timing.py | 1 + muutils/misc.py | 35 ++++++++--------- muutils/nbutils/convert_ipynb_to_script.py | 1 + muutils/statcounter.py | 8 ++-- muutils/sysinfo.py | 12 +++--- muutils/tensor_utils.py | 20 +++++----- .../test_sdc_properties_nested.py | 14 ++++++- .../test_serializable_dataclass.py | 1 + tests/unit/logger/test_logger.py | 2 + tests/unit/logger/test_timer_context.py | 2 + tests/unit/misc/test_freeze.py | 1 + tests/unit/misc/test_misc.py | 1 + tests/unit/misc/test_numerical_conversions.py | 1 + tests/unit/nbutils/test_conversion.py | 1 + tests/unit/test_group_equiv.py | 3 +- tests/unit/test_statcounter.py | 1 + tests/unit/test_tensor_utils.py | 1 + 24 files changed, 102 insertions(+), 60 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index ac051d0b..555ddd74 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -2,7 +2,7 @@ import typing import warnings from collections import defaultdict -from typing import Any, Callable, Generic, Hashable, Iterable, Literal, TypeVar, Dict +from typing import Any, Callable, Generic, Hashable, Iterable, Literal, TypeVar, Dict, Union, Optional, Tuple _KT = TypeVar("_KT") _VT = TypeVar("_VT") @@ -30,7 +30,7 @@ def _recursive_defaultdict_ctor() -> defaultdict: return defaultdict(_recursive_defaultdict_ctor) -def defaultdict_to_dict_recursive(dd: defaultdict | DefaulterDict) -> dict: +def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict: """Convert a defaultdict or DefaulterDict to a normal dict, recursively""" return { key: ( @@ -130,9 +130,9 @@ def update_with_nested_dict( def kwargs_to_nested_dict( kwargs_dict: Dict[str, Any], sep: str = ".", - strip_prefix: str | None = None, + strip_prefix: Optional[str] = None, when_unknown_prefix: typing.Literal["raise", "warn", "ignore"] = "warn", - transform_key: Callable[[str], str] | None = None, + transform_key: Optional[Callable[[str], str]] = None, ) -> Dict[str, Any]: """given kwargs from fire, convert them to a nested dict @@ -157,7 +157,7 @@ def main(**kwargs): the kwargs dict to convert - `sep: str = "."` the separator to use for nested keys - - `strip_prefix: str | None = None` + - `strip_prefix: Optional[str] = None` if not None, then all keys must start with this prefix - `when_unknown_prefix: typing.Literal["raise", "warn", "ignore"] = "warn"` what to do when an unknown prefix is found @@ -246,7 +246,7 @@ def condense_nested_dicts_numeric_keys( def condense_nested_dicts_matching_values( data: Dict[str, Any], - val_condense_fallback_mapping: Callable[[Any], Hashable] | None = None, + val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, ) -> Dict[str, Any]: """condense a nested dict, by condensing keys with matching values @@ -302,7 +302,7 @@ def condense_nested_dicts( data: Dict[str, Any], condense_numeric_keys: bool = True, condense_matching_values: bool = True, - val_condense_fallback_mapping: Callable[[Any], Hashable] | None = None, + val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, ) -> Dict[str, Any]: """condense a nested dict, by condensing numeric or matching keys with matching values to ranges @@ -337,8 +337,8 @@ def condense_nested_dicts( def tuple_dims_replace( - t: tuple[int, ...], dims_names_map: Dict[int, str] | None = None -) -> tuple[int | str, ...]: + t: Tuple[int, ...], dims_names_map: Optional[Dict[int, str]] = None +) -> Tuple[Union[int, str], ...]: if dims_names_map is None: return t else: @@ -346,7 +346,7 @@ def tuple_dims_replace( TensorDict = Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] -TensorIterable = Iterable[tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] +TensorIterable = Iterable[Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] TensorDictFormats = Literal["dict", "json", "yaml", "yml"] @@ -361,19 +361,19 @@ def condense_tensor_dict( shapes_convert: Callable[[tuple], Any] = _default_shapes_convert, drop_batch_dims: int = 0, sep: str = ".", - dims_names_map: Dict[int, str] | None = None, + dims_names_map: Optional[Dict[int, str]] = None, condense_numeric_keys: bool = True, condense_matching_values: bool = True, - val_condense_fallback_mapping: Callable[[Any], Hashable] | None = None, - return_format: TensorDictFormats | None = None, -) -> str | Dict[str, str | tuple[int, ...]]: + val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, + return_format: Optional[TensorDictFormats] = None, +) -> Union[str, Dict[str, str | Tuple[int, ...]]]: """Convert a dictionary of tensors to a dictionary of shapes. by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`. # Parameters: - - `data : Dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]` + - `data : Dict[str, "torch.Tensor|np.ndarray"] | Iterable[Tuple[str, "torch.Tensor|np.ndarray"]]` a either a `TensorDict` dict from strings to tensors, or an `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` ) - `fmt : TensorDictFormats` format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. @@ -403,7 +403,7 @@ def condense_tensor_dict( legacy alias for `fmt` kwarg # Returns: - - `str|Dict[str, str|tuple[int, ...]]` + - `str|Dict[str, str|Tuple[int, ...]]` dict if `return_format='dict'`, a string for `json` or `yaml` output # Examples: @@ -454,12 +454,12 @@ def condense_tensor_dict( shapes_convert = lambda x: x # convert to iterable - data_items: Iterable[tuple[str, "torch.Tensor|np.ndarray"]] = ( # type: ignore + data_items: "Iterable[Tuple[str, Union[torch.Tensor,np.ndarray]]]" = ( # type: ignore data.items() if hasattr(data, "items") and callable(data.items) else data # type: ignore ) # get shapes - data_shapes: Dict[str, str | tuple[int, ...]] = { + data_shapes: Dict[str, Union[str, Tuple[int, ...]]] = { k: shapes_convert( tuple_dims_replace( tuple(v.shape)[drop_batch_dims:], @@ -473,7 +473,7 @@ def condense_tensor_dict( data_nested: Dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) # condense the nested dict - data_condensed: Dict[str, str | tuple[int, ...]] = condense_nested_dicts( + data_condensed: Dict[str, Union[str, Tuple[int, ...]]] = condense_nested_dicts( data=data_nested, condense_numeric_keys=condense_numeric_keys, condense_matching_values=condense_matching_values, diff --git a/muutils/group_equiv.py b/muutils/group_equiv.py index 0a45a55a..fa50dadb 100644 --- a/muutils/group_equiv.py +++ b/muutils/group_equiv.py @@ -1,3 +1,4 @@ +from __future__ import annotations from itertools import chain from typing import Callable, Sequence, TypeVar diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 3c49e003..f60e059d 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -432,6 +432,16 @@ def wrap(cls: Type[T]) -> Type[T]: field_value = serializable_field() setattr(cls, field_name, field_value) + # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy + if (sys.version_info[1] < 9): + if "kw_only" in kwargs: + if (kwargs["kw_only"] == True): # noqa: E712 + raise ValueError( + "kw_only is not supported in python >=3.9" + ) + else: + del kwargs["kw_only"] + cls = dataclasses.dataclass( # type: ignore[call-overload] cls, init=init, diff --git a/muutils/logger/headerfuncs.py b/muutils/logger/headerfuncs.py index 11b97c1a..09257de9 100644 --- a/muutils/logger/headerfuncs.py +++ b/muutils/logger/headerfuncs.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json from typing import Any, Mapping, Protocol diff --git a/muutils/logger/logger.py b/muutils/logger/logger.py index 03280db0..d888cf97 100644 --- a/muutils/logger/logger.py +++ b/muutils/logger/logger.py @@ -6,6 +6,7 @@ this was mostly made with training models in mind and storing both metadata and loss - `TimerContext` is a context manager that can be used to time the duration of a block of code """ +from __future__ import annotations import json import time diff --git a/muutils/logger/loggingstream.py b/muutils/logger/loggingstream.py index 16dd742f..a28bdc8f 100644 --- a/muutils/logger/loggingstream.py +++ b/muutils/logger/loggingstream.py @@ -1,3 +1,4 @@ +from __future__ import annotations import time from dataclasses import dataclass, field from typing import Any, Callable diff --git a/muutils/logger/simplelogger.py b/muutils/logger/simplelogger.py index bce0046b..1f9adda8 100644 --- a/muutils/logger/simplelogger.py +++ b/muutils/logger/simplelogger.py @@ -1,8 +1,9 @@ +from __future__ import annotations import json import sys import time import typing -from typing import TextIO +from typing import TextIO, Union from muutils.json_serialize import JSONitem, json_serialize @@ -26,7 +27,7 @@ def close(self) -> None: pass -AnyIO = TextIO | NullIO +AnyIO = Union[TextIO, NullIO] class SimpleLogger: diff --git a/muutils/logger/timing.py b/muutils/logger/timing.py index ba8e43bd..56102c55 100644 --- a/muutils/logger/timing.py +++ b/muutils/logger/timing.py @@ -1,3 +1,4 @@ +from __future__ import annotations import time from typing import Literal diff --git a/muutils/misc.py b/muutils/misc.py index cc2ea074..bf38f82d 100644 --- a/muutils/misc.py +++ b/muutils/misc.py @@ -1,3 +1,4 @@ +from __future__ import annotations import hashlib import typing @@ -210,23 +211,23 @@ def str_to_numeric( # detect if it has a suffix suffixes_detected: list[bool] = [suffix in quantity for suffix in _mapping] - match sum(suffixes_detected): - case 0: - # no suffix - pass - case 1: - # find multiplier - for suffix, mult in _mapping.items(): - if quantity.endswith(suffix): - # remove suffix, store multiplier, and break - quantity = quantity.removesuffix(suffix).strip() - multiplier = mult - break - else: - raise ValueError(f"Invalid suffix in {quantity_original}") - case _: - # multiple suffixes - raise ValueError(f"Multiple suffixes detected in {quantity_original}") + n_suffixes_detected: int = sum(suffixes_detected) + if n_suffixes_detected == 0: + # no suffix + pass + elif n_suffixes_detected == 1: + # find multiplier + for suffix, mult in _mapping.items(): + if quantity.endswith(suffix): + # remove suffix, store multiplier, and break + quantity = quantity.removesuffix(suffix).strip() + multiplier = mult + break + else: + raise ValueError(f"Invalid suffix in {quantity_original}") + else: + # multiple suffixes + raise ValueError(f"Multiple suffixes detected in {quantity_original}") # fractions if "/" in quantity: diff --git a/muutils/nbutils/convert_ipynb_to_script.py b/muutils/nbutils/convert_ipynb_to_script.py index 130d0750..91eae765 100644 --- a/muutils/nbutils/convert_ipynb_to_script.py +++ b/muutils/nbutils/convert_ipynb_to_script.py @@ -1,3 +1,4 @@ +from __future__ import annotations import argparse import json import os diff --git a/muutils/statcounter.py b/muutils/statcounter.py index 53e466d0..d754758f 100644 --- a/muutils/statcounter.py +++ b/muutils/statcounter.py @@ -1,9 +1,9 @@ +from __future__ import annotations import json import math from collections import Counter from functools import cached_property from itertools import chain -from types import NoneType from typing import Callable, Optional, Sequence, Union # _GeneralArray = Union[np.ndarray, "torch.Tensor"] @@ -16,7 +16,7 @@ def universal_flatten( - arr: NumericSequence | float | int, require_rectangular: bool = True + arr: Union[NumericSequence, float, int], require_rectangular: bool = True ) -> NumericSequence: """flattens any iterable""" @@ -24,7 +24,7 @@ def universal_flatten( if hasattr(arr, "flatten") and callable(arr.flatten): # type: ignore return arr.flatten() # type: ignore elif isinstance(arr, Sequence): - elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr] + elements_iterable: List[bool] = [isinstance(x, Sequence) for x in arr] if require_rectangular and (all(elements_iterable) != any(elements_iterable)): raise ValueError("arr contains mixed iterable and non-iterable elements") if any(elements_iterable): @@ -47,7 +47,7 @@ class StatCounter(Counter): def validate(self) -> bool: """validate the counter as being all floats or ints""" - return all(isinstance(k, (bool, int, float, NoneType)) for k in self.keys()) + return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys()) def min(self): return min(x for x, v in self.items() if v > 0) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index 4efa6325..543a4c85 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -6,12 +6,12 @@ from pip._internal.operations.freeze import freeze as pip_freeze -def _popen(cmd: list[str], split_out: bool = False) -> dict[str, typing.Any]: +def _popen(cmd: typing.List[str], split_out: bool = False) -> typing.Dict[str, typing.Any]: p: subprocess.Popen = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - p_out: str | list[str] | None + p_out: typing.Union[str, typing.List[str], None] if p.stdout is not None: p_out = p.stdout.read().decode("utf-8") if split_out: @@ -46,7 +46,7 @@ def python() -> dict: @staticmethod def pip() -> dict: """installed packages info""" - pckgs: list[str] = [x for x in pip_freeze(local_only=True)] + pckgs: typing.List[str] = [x for x in pip_freeze(local_only=True)] return { "n_packages": len(pckgs), "packages": pckgs, @@ -158,10 +158,10 @@ def git_info() -> dict: @classmethod def get_all( cls, - include: tuple[str, ...] | None = None, - exclude: tuple[str, ...] = tuple(), + include: typing.Optional[typing.Tuple[str, ...]] = None, + exclude: typing.Tuple[str, ...] = tuple(), ) -> dict: - include_meta: tuple[str, ...] + include_meta: typing.Tuple[str, ...] if include is None: include_meta = tuple(cls.__dict__.keys()) else: diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index 97e8446b..4c0a50f9 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -99,7 +99,7 @@ def param_info(cls, params) -> str: ) @typing._tp_cache # type: ignore - def __class_getitem__(cls, params: str | tuple) -> type: + def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: # MyTensor["dim1 dim2"] if isinstance(params, str): return default_jax_dtype[array_type, params] @@ -124,7 +124,7 @@ def __class_getitem__(cls, params: str | tuple) -> type: f"legacy type annotation was used:\n{cls.param_info(params)}" ) # MyTensor[("dim1", "dim2"), int] - shape_anot: list[str] = list() + shape_anot: typing.List[str] = list() for x in params[0]: if isinstance(x, str): shape_anot.append(x) @@ -178,7 +178,7 @@ def __class_getitem__(cls, params): NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] -def numpy_to_torch_dtype(dtype: np.dtype | torch.dtype) -> torch.dtype: +def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype: """convert numpy dtype to torch dtype""" if isinstance(dtype, torch.dtype): return dtype @@ -261,7 +261,7 @@ def numpy_to_torch_dtype(dtype: np.dtype | torch.dtype) -> torch.dtype: TORCH_DTYPE_MAP["bool"] = torch.bool -TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = { +TORCH_OPTIMIZERS_MAP: typing.Dict[str, typing.Type[torch.optim.Optimizer]] = { "Adagrad": torch.optim.Adagrad, "Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW, @@ -287,7 +287,7 @@ def pad_tensor( set `rpad = True` to pad on the right instead""" - temp: list[torch.Tensor] = [ + temp: typing.List[torch.Tensor] = [ torch.full( (padded_length - tensor.shape[0],), pad_value, @@ -326,7 +326,7 @@ def pad_array( set `rpad = True` to pad on the right instead""" - temp: list[np.ndarray] = [ + temp: typing.List[np.ndarray] = [ np.full( (padded_length - array.shape[0],), pad_value, @@ -355,12 +355,12 @@ def rpad_array( return pad_array(array, pad_length, pad_value, rpad=True) -def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]: +def get_dict_shapes(d: typing.Dict[str, "torch.Tensor"]) -> typing.Dict[str, typing.Tuple[int, ...]]: """given a state dict or cache dict, compute the shapes and put them in a nested dict""" return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) -def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str: +def string_dict_shapes(d: typing.Dict[str, "torch.Tensor"]) -> str: """printable version of get_dict_shapes""" return json.dumps( dotlist_to_nested_dict( @@ -428,8 +428,8 @@ def compare_state_dicts( ) # check tensors match - shape_failed: list[str] = list() - vals_failed: list[str] = list() + shape_failed: typing.List[str] = list() + vals_failed: typing.List[str] = list() for k, v1 in d1.items(): v2 = d2[k] # check shapes first diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 5e951b01..3e22fac7 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -1,5 +1,13 @@ +from __future__ import annotations +import sys + +import pytest + from muutils.json_serialize import SerializableDataclass, serializable_dataclass +SUPPORS_KW_ONLY: bool = sys.version_info[1] >= 9 + +print(f"{SUPPORS_KW_ONLY = }") @serializable_dataclass class Person(SerializableDataclass): @@ -12,7 +20,7 @@ def full_name(self) -> str: @serializable_dataclass( - kw_only=True, properties_to_serialize=["full_name", "full_title"] + kw_only=SUPPORS_KW_ONLY, properties_to_serialize=["full_name", "full_title"] ) class TitledPerson(Person): title: str @@ -41,6 +49,10 @@ def test_serialize_person(): def test_serialize_titled_person(): instance = TitledPerson(first_name="Jane", last_name="Smith", title="Dr.") + if SUPPORS_KW_ONLY: + with pytest.raises(TypeError): + TitledPerson("Jane", "Smith", "Dr.") + serialized = instance.serialize() assert serialized == { diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 2ebcd2fb..56423380 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Any import pytest diff --git a/tests/unit/logger/test_logger.py b/tests/unit/logger/test_logger.py index 29340d29..15be01dc 100644 --- a/tests/unit/logger/test_logger.py +++ b/tests/unit/logger/test_logger.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from muutils.logger import Logger diff --git a/tests/unit/logger/test_timer_context.py b/tests/unit/logger/test_timer_context.py index 63e48786..cd9d8bc2 100644 --- a/tests/unit/logger/test_timer_context.py +++ b/tests/unit/logger/test_timer_context.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from muutils.logger.timing import TimerContext diff --git a/tests/unit/misc/test_freeze.py b/tests/unit/misc/test_freeze.py index a0a2bd99..238ef6f9 100644 --- a/tests/unit/misc/test_freeze.py +++ b/tests/unit/misc/test_freeze.py @@ -1,3 +1,4 @@ +from __future__ import annotations import pytest from muutils.misc import freeze diff --git a/tests/unit/misc/test_misc.py b/tests/unit/misc/test_misc.py index 01c0aec8..e22380af 100644 --- a/tests/unit/misc/test_misc.py +++ b/tests/unit/misc/test_misc.py @@ -1,3 +1,4 @@ +from __future__ import annotations import pytest from muutils.misc import ( diff --git a/tests/unit/misc/test_numerical_conversions.py b/tests/unit/misc/test_numerical_conversions.py index f1b0242b..f2175f2e 100644 --- a/tests/unit/misc/test_numerical_conversions.py +++ b/tests/unit/misc/test_numerical_conversions.py @@ -1,3 +1,4 @@ +from __future__ import annotations import random from math import isclose, isinf, isnan diff --git a/tests/unit/nbutils/test_conversion.py b/tests/unit/nbutils/test_conversion.py index 52f9f01f..8301f2da 100644 --- a/tests/unit/nbutils/test_conversion.py +++ b/tests/unit/nbutils/test_conversion.py @@ -1,3 +1,4 @@ +from __future__ import annotations import itertools import os diff --git a/tests/unit/test_group_equiv.py b/tests/unit/test_group_equiv.py index 80e8bc3b..a2499fc3 100644 --- a/tests/unit/test_group_equiv.py +++ b/tests/unit/test_group_equiv.py @@ -1,4 +1,5 @@ -# Assuming your functions are in a file named `group_by.py` +from __future__ import annotations + from muutils.group_equiv import group_by_equivalence diff --git a/tests/unit/test_statcounter.py b/tests/unit/test_statcounter.py index d10f1f04..d373cf95 100644 --- a/tests/unit/test_statcounter.py +++ b/tests/unit/test_statcounter.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np from muutils.statcounter import StatCounter diff --git a/tests/unit/test_tensor_utils.py b/tests/unit/test_tensor_utils.py index 1574edba..17c4ed15 100644 --- a/tests/unit/test_tensor_utils.py +++ b/tests/unit/test_tensor_utils.py @@ -1,3 +1,4 @@ +from __future__ import annotations import jaxtyping import numpy as np import pytest From f4724eb202d04e6d6d93315572de983c17cd018f Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 20:06:54 -0700 Subject: [PATCH 008/158] sanitize_name function, with wrappers for fnames and identifiers --- muutils/misc.py | 76 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 62 insertions(+), 14 deletions(-) diff --git a/muutils/misc.py b/muutils/misc.py index bf38f82d..a83a930a 100644 --- a/muutils/misc.py +++ b/muutils/misc.py @@ -62,26 +62,74 @@ def list_join(lst: list, factory: typing.Callable) -> list: return output -# filename stuff +# name stuff # ================================================================================ +def sanitize_name( + name: str | None, + additional_allowed_chars: str = "", + replace_invalid: str = "", + when_none: str | None = "_None_", + leading_digit_prefix: str = "", + ) -> str: + """sanitize a string, leaving only alphanumerics and `additional_allowed_chars` + + # Parameters: + - `name : str | None` + input string + - `additional_allowed_chars : str` + additional characters to allow, none by default + (defaults to `""`) + - `replace_invalid : str` + character to replace invalid characters with + (defaults to `""`) + - `when_none : str | None` + string to return if `name` is `None`. if `None`, raises an exception + (defaults to `"_None_"`) + - `leading_digit_prefix : str` + character to prefix the string with if it starts with a digit + (defaults to `""`) + + # Returns: + - `str` + sanitized string + """ + + + if name is None: + if when_none is None: + raise ValueError("name is None") + else: + return when_none -def sanitize_fname(fname: str | None) -> str: - """sanitize a filename for use in a path""" - if fname is None: - return "_None_" - - fname_sanitized: str = "" - for char in fname: + sanitized: str = "" + for char in name: if char.isalnum(): - fname_sanitized += char - elif char in ("-", "_", "."): - fname_sanitized += char + sanitized += char + elif char in additional_allowed_chars: + sanitized += char else: - fname_sanitized += "" + sanitized += replace_invalid + + if sanitized[0].isdigit(): + sanitized = leading_digit_prefix + sanitized - return fname_sanitized + return sanitized +def sanitize_fname(fname: str | None, **kwargs) -> str: + """sanitize a filename to posix standards + + - leave only alphanumerics, `_` (underscore), '-' (dash) and `.` (period) + """ + return sanitize_name(fname, additional_allowed_chars="._-", **kwargs) + +def sanitize_identifier(fname: str | None, **kwargs) -> str: + """sanitize an identifier (variable or function name) + + - leave only alphanumerics and `_` (underscore) + - prefix with `_` if it starts with a digit + """ + return sanitize_name(fname, additional_allowed_chars="_", leading_digit_prefix="_", **kwargs) def dict_to_filename( data: dict, @@ -220,7 +268,7 @@ def str_to_numeric( for suffix, mult in _mapping.items(): if quantity.endswith(suffix): # remove suffix, store multiplier, and break - quantity = quantity.removesuffix(suffix).strip() + quantity = quantity[: -len(suffix)].strip() multiplier = mult break else: From fd2d413b538cbf3f49f7f66c50f1944720551701 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 20:12:46 -0700 Subject: [PATCH 009/158] tests for new sanitize_name and derived --- tests/unit/misc/test_misc.py | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/unit/misc/test_misc.py b/tests/unit/misc/test_misc.py index e22380af..c752f725 100644 --- a/tests/unit/misc/test_misc.py +++ b/tests/unit/misc/test_misc.py @@ -1,12 +1,15 @@ from __future__ import annotations import pytest +import pytest from muutils.misc import ( dict_to_filename, freeze, list_join, list_split, + sanitize_name, sanitize_fname, + sanitize_identifier, stable_hash, ) @@ -34,6 +37,41 @@ def test_sanitize_fname(): assert sanitize_fname(None) == "_None_", "None input should return '_None_'" + + +def test_sanitize_name(): + assert sanitize_name("Hello World") == "HelloWorld" + assert sanitize_name("Hello_World", additional_allowed_chars="_") == "Hello_World" + assert sanitize_name("Hello!World", replace_invalid="-") == "Hello-World" + assert sanitize_name(None) == "_None_" + assert sanitize_name(None, when_none="Empty") == "Empty" + with pytest.raises(ValueError): + sanitize_name(None, when_none=None) + assert sanitize_name("123abc") == "123abc" + assert sanitize_name("123abc", leading_digit_prefix="_") == "_123abc" + +def test_sanitize_fname_2(): + assert sanitize_fname("file name.txt") == "filename.txt" + assert sanitize_fname("file_name.txt") == "file_name.txt" + assert sanitize_fname("file-name.txt") == "file-name.txt" + assert sanitize_fname("file!name.txt") == "filename.txt" + assert sanitize_fname(None) == "_None_" + assert sanitize_fname(None, when_none="Empty") == "Empty" + with pytest.raises(ValueError): + sanitize_fname(None, when_none=None) + assert sanitize_fname("123file.txt") == "123file.txt" + assert sanitize_fname("123file.txt", leading_digit_prefix="_") == "_123file.txt" + +def test_sanitize_identifier(): + assert sanitize_identifier("variable_name") == "variable_name" + assert sanitize_identifier("VariableName") == "VariableName" + assert sanitize_identifier("variable!name") == "variablename" + assert sanitize_identifier("123variable") == "_123variable" + assert sanitize_identifier(None) == "_None_" + assert sanitize_identifier(None, when_none="Empty") == "Empty" + with pytest.raises(ValueError): + sanitize_identifier(None, when_none=None) + def test_dict_to_filename(): data = {"key1": "value1", "key2": "value2"} assert ( From c68499da95a8c530290dfde89d985eb4c2382909 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 20:14:38 -0700 Subject: [PATCH 010/158] some fixes --- muutils/dictmagic.py | 3 ++- muutils/mlutils.py | 9 ++++++++- muutils/statcounter.py | 4 ++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 555ddd74..21bb505f 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -178,7 +178,8 @@ def main(**kwargs): raise ValueError( f"when_unknown_prefix must be one of 'raise', 'warn', or 'ignore', got {when_unknown_prefix}" ) - key = key.removeprefix(strip_prefix) + else: + key = key[len(strip_prefix) :] if transform_key is not None: key = transform_key(key) diff --git a/muutils/mlutils.py b/muutils/mlutils.py index f094ce7d..f5dcffa8 100644 --- a/muutils/mlutils.py +++ b/muutils/mlutils.py @@ -136,8 +136,15 @@ def register_method( """Decorator to add a method to the method_dict""" def decorator(method: F) -> F: + method_name: str if custom_name is None: - method_name: str = method.__name__ + method_name: str|None = getattr(method, "__name__", None) + if method_name is None: + warnings.warn( + f"Method {method} does not have a name, using sanitized repr" + ) + from muutils.misc import sanitize_identifier + method_name = sanitize_identifier(repr(method)) else: method_name = custom_name method.__name__ = custom_name diff --git a/muutils/statcounter.py b/muutils/statcounter.py index d754758f..014bf6b8 100644 --- a/muutils/statcounter.py +++ b/muutils/statcounter.py @@ -54,6 +54,10 @@ def min(self): def max(self): return max(x for x, v in self.items() if v > 0) + + def total(self): + """Sum of the counts""" + return sum(self.values()) @cached_property def keys_sorted(self) -> list: From d50a38d1c9a7d4cff3fb9ac6727b17f291f5beb0 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 20:29:17 -0700 Subject: [PATCH 011/158] wip --- muutils/__init__.py | 1 + muutils/json_serialize/__init__.py | 1 + muutils/json_serialize/json_serialize.py | 2 +- .../serializable_dataclass/test_helpers.py | 1 + tests/unit/test_mlutils.py | 9 +++++++-- 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/muutils/__init__.py b/muutils/__init__.py index e69de29b..6c43ea25 100644 --- a/muutils/__init__.py +++ b/muutils/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations \ No newline at end of file diff --git a/muutils/json_serialize/__init__.py b/muutils/json_serialize/__init__.py index 4582116d..8cfcdef7 100644 --- a/muutils/json_serialize/__init__.py +++ b/muutils/json_serialize/__init__.py @@ -1,3 +1,4 @@ +from __future__ import annotations from muutils.json_serialize.array import arr_metadata, load_array from muutils.json_serialize.json_serialize import ( BASE_HANDLERS, diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index b6097e0d..e362c2d3 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -100,7 +100,7 @@ def serialize(self) -> dict: BASE_HANDLERS: MonoTuple[SerializerHandler] = ( SerializerHandler( check=lambda self, obj, path: isinstance( - obj, (bool, int, float, str, types.NoneType) + obj, (bool, int, float, str, type(None)) ), serialize_func=lambda self, obj, path: obj, uid="base types", diff --git a/tests/unit/json_serialize/serializable_dataclass/test_helpers.py b/tests/unit/json_serialize/serializable_dataclass/test_helpers.py index 3184ba7f..f5fb8cb1 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_helpers.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_helpers.py @@ -1,3 +1,4 @@ +from __future__ import annotations from dataclasses import dataclass import numpy as np diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index a3dbca14..81e3f1ca 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -1,4 +1,5 @@ from pathlib import Path +import sys from muutils.mlutils import get_checkpoint_paths_for_run, register_method @@ -44,5 +45,9 @@ def other_eval_function(): evalsA = TestEvalsA.evals evalsB = TestEvalsB.evals - assert list(evalsA.keys()) == ["eval_function"] - assert list(evalsB.keys()) == ["other_eval_function"] + if sys.version_info >= (3, 9): + assert list(evalsA.keys()) == ["eval_function"] + assert list(evalsB.keys()) == ["other_eval_function"] + else: + assert len(evalsA) == 1 + assert len(evalsB) == 1 From 695aed171723781ef7e3ac8df761af2745404495 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 20:29:39 -0700 Subject: [PATCH 012/158] fix formatting --- muutils/__init__.py | 2 +- muutils/dictmagic.py | 15 +++++- muutils/group_equiv.py | 1 + muutils/json_serialize/__init__.py | 1 + muutils/json_serialize/array.py | 1 + muutils/json_serialize/json_serialize.py | 6 +-- .../json_serialize/serializable_dataclass.py | 20 ++++---- muutils/json_serialize/util.py | 8 ++-- muutils/logger/headerfuncs.py | 1 + muutils/logger/logger.py | 1 + muutils/logger/loggingstream.py | 1 + muutils/logger/simplelogger.py | 1 + muutils/logger/timing.py | 1 + muutils/misc.py | 46 +++++++++++-------- muutils/mlutils.py | 6 ++- muutils/nbutils/configure_notebook.py | 1 + muutils/nbutils/convert_ipynb_to_script.py | 1 + muutils/statcounter.py | 3 +- muutils/sysinfo.py | 4 +- muutils/tensor_utils.py | 4 +- .../serializable_dataclass/test_helpers.py | 1 + .../test_sdc_defaults.py | 1 + .../test_sdc_properties_nested.py | 2 + .../test_serializable_dataclass.py | 1 + tests/unit/misc/test_freeze.py | 1 + tests/unit/misc/test_misc.py | 9 ++-- tests/unit/misc/test_numerical_conversions.py | 1 + tests/unit/nbutils/test_conversion.py | 1 + tests/unit/test_mlutils.py | 2 +- tests/unit/test_statcounter.py | 1 + tests/unit/test_tensor_utils.py | 1 + 31 files changed, 93 insertions(+), 52 deletions(-) diff --git a/muutils/__init__.py b/muutils/__init__.py index 6c43ea25..9d48db4f 100644 --- a/muutils/__init__.py +++ b/muutils/__init__.py @@ -1 +1 @@ -from __future__ import annotations \ No newline at end of file +from __future__ import annotations diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 21bb505f..c17ec0fc 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -1,8 +1,21 @@ from __future__ import annotations + import typing import warnings from collections import defaultdict -from typing import Any, Callable, Generic, Hashable, Iterable, Literal, TypeVar, Dict, Union, Optional, Tuple +from typing import ( + Any, + Callable, + Dict, + Generic, + Hashable, + Iterable, + Literal, + Optional, + Tuple, + TypeVar, + Union, +) _KT = TypeVar("_KT") _VT = TypeVar("_VT") diff --git a/muutils/group_equiv.py b/muutils/group_equiv.py index fa50dadb..dc722235 100644 --- a/muutils/group_equiv.py +++ b/muutils/group_equiv.py @@ -1,4 +1,5 @@ from __future__ import annotations + from itertools import chain from typing import Callable, Sequence, TypeVar diff --git a/muutils/json_serialize/__init__.py b/muutils/json_serialize/__init__.py index 8cfcdef7..263ed273 100644 --- a/muutils/json_serialize/__init__.py +++ b/muutils/json_serialize/__init__.py @@ -1,4 +1,5 @@ from __future__ import annotations + from muutils.json_serialize.array import arr_metadata, load_array from muutils.json_serialize.json_serialize import ( BASE_HANDLERS, diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 8130d160..866f6d43 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -1,4 +1,5 @@ from __future__ import annotations + import typing import warnings from typing import Any, Iterable, Literal, Optional, Sequence diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index e362c2d3..42d1f292 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -1,10 +1,10 @@ from __future__ import annotations + import inspect -import types import warnings from dataclasses import dataclass, is_dataclass from pathlib import Path -from typing import Any, Callable, Iterable, Mapping, Dict, Set, Union +from typing import Any, Callable, Dict, Iterable, Mapping, Set, Union try: from muutils.json_serialize.array import ArrayMode, serialize_array @@ -54,7 +54,7 @@ "", } -ObjectPath = MonoTuple[Union[str,int]] +ObjectPath = MonoTuple[Union[str, int]] @dataclass diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index f60e059d..2d60150b 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -1,4 +1,5 @@ from __future__ import annotations + import abc import dataclasses import json @@ -38,8 +39,7 @@ def __init__( self, default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, default_factory: Union[ - Callable[[], Any], - dataclasses._MISSING_TYPE + Callable[[], Any], dataclasses._MISSING_TYPE ] = dataclasses.MISSING, init: bool = True, repr: bool = True, @@ -74,11 +74,9 @@ def __init__( super_kwargs["metadata"] = types.MappingProxyType({}) # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy - if (sys.version_info[1] < 9): - if (super_kwargs["kw_only"] == True): # noqa: E712 - raise ValueError( - "kw_only is not supported in python >=3.9" - ) + if sys.version_info[1] < 9: + if super_kwargs["kw_only"] == True: # noqa: E712 + raise ValueError("kw_only is not supported in python >=3.9") else: del super_kwargs["kw_only"] @@ -433,12 +431,10 @@ def wrap(cls: Type[T]) -> Type[T]: setattr(cls, field_name, field_value) # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy - if (sys.version_info[1] < 9): + if sys.version_info[1] < 9: if "kw_only" in kwargs: - if (kwargs["kw_only"] == True): # noqa: E712 - raise ValueError( - "kw_only is not supported in python >=3.9" - ) + if kwargs["kw_only"] == True: # noqa: E712 + raise ValueError("kw_only is not supported in python >=3.9") else: del kwargs["kw_only"] diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index eedcde61..7787725b 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -1,11 +1,11 @@ from __future__ import annotations + import functools import inspect import sys -import types import typing import warnings -from typing import Any, Callable, Iterable, Literal, Union, Dict +from typing import Any, Callable, Dict, Iterable, Literal, Union _NUMPY_WORKING: bool try: @@ -54,9 +54,7 @@ def __class_getitem__(cls, params): elif len(params) == 1: return typing.GenericAlias(tuple, (params[0], Ellipsis)) else: - raise TypeError( - f"MonoTuple expects 1 type argument, got {params = }" - ) + raise TypeError(f"MonoTuple expects 1 type argument, got {params = }") class UniversalContainer: diff --git a/muutils/logger/headerfuncs.py b/muutils/logger/headerfuncs.py index 09257de9..8f327268 100644 --- a/muutils/logger/headerfuncs.py +++ b/muutils/logger/headerfuncs.py @@ -1,4 +1,5 @@ from __future__ import annotations + import json from typing import Any, Mapping, Protocol diff --git a/muutils/logger/logger.py b/muutils/logger/logger.py index d888cf97..20d46755 100644 --- a/muutils/logger/logger.py +++ b/muutils/logger/logger.py @@ -6,6 +6,7 @@ this was mostly made with training models in mind and storing both metadata and loss - `TimerContext` is a context manager that can be used to time the duration of a block of code """ + from __future__ import annotations import json diff --git a/muutils/logger/loggingstream.py b/muutils/logger/loggingstream.py index a28bdc8f..77cad982 100644 --- a/muutils/logger/loggingstream.py +++ b/muutils/logger/loggingstream.py @@ -1,4 +1,5 @@ from __future__ import annotations + import time from dataclasses import dataclass, field from typing import Any, Callable diff --git a/muutils/logger/simplelogger.py b/muutils/logger/simplelogger.py index 1f9adda8..07f1a306 100644 --- a/muutils/logger/simplelogger.py +++ b/muutils/logger/simplelogger.py @@ -1,4 +1,5 @@ from __future__ import annotations + import json import sys import time diff --git a/muutils/logger/timing.py b/muutils/logger/timing.py index 56102c55..c0c8e0c6 100644 --- a/muutils/logger/timing.py +++ b/muutils/logger/timing.py @@ -1,4 +1,5 @@ from __future__ import annotations + import time from typing import Literal diff --git a/muutils/misc.py b/muutils/misc.py index a83a930a..dd75ad4c 100644 --- a/muutils/misc.py +++ b/muutils/misc.py @@ -1,4 +1,5 @@ from __future__ import annotations + import hashlib import typing @@ -65,37 +66,37 @@ def list_join(lst: list, factory: typing.Callable) -> list: # name stuff # ================================================================================ + def sanitize_name( - name: str | None, - additional_allowed_chars: str = "", - replace_invalid: str = "", - when_none: str | None = "_None_", - leading_digit_prefix: str = "", - ) -> str: + name: str | None, + additional_allowed_chars: str = "", + replace_invalid: str = "", + when_none: str | None = "_None_", + leading_digit_prefix: str = "", +) -> str: """sanitize a string, leaving only alphanumerics and `additional_allowed_chars` - + # Parameters: - - `name : str | None` + - `name : str | None` input string - - `additional_allowed_chars : str` + - `additional_allowed_chars : str` additional characters to allow, none by default (defaults to `""`) - - `replace_invalid : str` + - `replace_invalid : str` character to replace invalid characters with (defaults to `""`) - - `when_none : str | None` + - `when_none : str | None` string to return if `name` is `None`. if `None`, raises an exception (defaults to `"_None_"`) - - `leading_digit_prefix : str` + - `leading_digit_prefix : str` character to prefix the string with if it starts with a digit (defaults to `""`) - + # Returns: - - `str` + - `str` sanitized string - """ + """ - if name is None: if when_none is None: raise ValueError("name is None") @@ -110,26 +111,31 @@ def sanitize_name( sanitized += char else: sanitized += replace_invalid - + if sanitized[0].isdigit(): sanitized = leading_digit_prefix + sanitized return sanitized + def sanitize_fname(fname: str | None, **kwargs) -> str: """sanitize a filename to posix standards - + - leave only alphanumerics, `_` (underscore), '-' (dash) and `.` (period) """ return sanitize_name(fname, additional_allowed_chars="._-", **kwargs) + def sanitize_identifier(fname: str | None, **kwargs) -> str: """sanitize an identifier (variable or function name) - + - leave only alphanumerics and `_` (underscore) - prefix with `_` if it starts with a digit """ - return sanitize_name(fname, additional_allowed_chars="_", leading_digit_prefix="_", **kwargs) + return sanitize_name( + fname, additional_allowed_chars="_", leading_digit_prefix="_", **kwargs + ) + def dict_to_filename( data: dict, diff --git a/muutils/mlutils.py b/muutils/mlutils.py index f5dcffa8..57fdc800 100644 --- a/muutils/mlutils.py +++ b/muutils/mlutils.py @@ -1,4 +1,5 @@ from __future__ import annotations + import json import os import random @@ -6,7 +7,7 @@ import warnings from itertools import islice from pathlib import Path -from typing import Any, Callable, TypeVar, Union, Optional +from typing import Any, Callable, Optional, TypeVar, Union ARRAY_IMPORTS: bool try: @@ -138,12 +139,13 @@ def register_method( def decorator(method: F) -> F: method_name: str if custom_name is None: - method_name: str|None = getattr(method, "__name__", None) + method_name: str | None = getattr(method, "__name__", None) if method_name is None: warnings.warn( f"Method {method} does not have a name, using sanitized repr" ) from muutils.misc import sanitize_identifier + method_name = sanitize_identifier(repr(method)) else: method_name = custom_name diff --git a/muutils/nbutils/configure_notebook.py b/muutils/nbutils/configure_notebook.py index 7cc80b00..e70116aa 100644 --- a/muutils/nbutils/configure_notebook.py +++ b/muutils/nbutils/configure_notebook.py @@ -1,4 +1,5 @@ from __future__ import annotations + import os import typing import warnings diff --git a/muutils/nbutils/convert_ipynb_to_script.py b/muutils/nbutils/convert_ipynb_to_script.py index 91eae765..8b65bcaf 100644 --- a/muutils/nbutils/convert_ipynb_to_script.py +++ b/muutils/nbutils/convert_ipynb_to_script.py @@ -1,4 +1,5 @@ from __future__ import annotations + import argparse import json import os diff --git a/muutils/statcounter.py b/muutils/statcounter.py index 014bf6b8..bce0b105 100644 --- a/muutils/statcounter.py +++ b/muutils/statcounter.py @@ -1,4 +1,5 @@ from __future__ import annotations + import json import math from collections import Counter @@ -54,7 +55,7 @@ def min(self): def max(self): return max(x for x, v in self.items() if v > 0) - + def total(self): """Sum of the counts""" return sum(self.values()) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index 543a4c85..5a4dd8b2 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -6,7 +6,9 @@ from pip._internal.operations.freeze import freeze as pip_freeze -def _popen(cmd: typing.List[str], split_out: bool = False) -> typing.Dict[str, typing.Any]: +def _popen( + cmd: typing.List[str], split_out: bool = False +) -> typing.Dict[str, typing.Any]: p: subprocess.Popen = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index 4c0a50f9..6be949dc 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -355,7 +355,9 @@ def rpad_array( return pad_array(array, pad_length, pad_value, rpad=True) -def get_dict_shapes(d: typing.Dict[str, "torch.Tensor"]) -> typing.Dict[str, typing.Tuple[int, ...]]: +def get_dict_shapes( + d: typing.Dict[str, "torch.Tensor"] +) -> typing.Dict[str, typing.Tuple[int, ...]]: """given a state dict or cache dict, compute the shapes and put them in a nested dict""" return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_helpers.py b/tests/unit/json_serialize/serializable_dataclass/test_helpers.py index f5fb8cb1..81c92b2d 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_helpers.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_helpers.py @@ -1,4 +1,5 @@ from __future__ import annotations + from dataclasses import dataclass import numpy as np diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index ef373bf3..47d373e0 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -1,4 +1,5 @@ from typing import Dict + from muutils.json_serialize import ( JsonSerializer, SerializableDataclass, diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 3e22fac7..8600c3c8 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -1,4 +1,5 @@ from __future__ import annotations + import sys import pytest @@ -9,6 +10,7 @@ print(f"{SUPPORS_KW_ONLY = }") + @serializable_dataclass class Person(SerializableDataclass): first_name: str diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 56423380..3f754f60 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -1,4 +1,5 @@ from __future__ import annotations + from typing import Any import pytest diff --git a/tests/unit/misc/test_freeze.py b/tests/unit/misc/test_freeze.py index 238ef6f9..819e1478 100644 --- a/tests/unit/misc/test_freeze.py +++ b/tests/unit/misc/test_freeze.py @@ -1,4 +1,5 @@ from __future__ import annotations + import pytest from muutils.misc import freeze diff --git a/tests/unit/misc/test_misc.py b/tests/unit/misc/test_misc.py index c752f725..ac5c4302 100644 --- a/tests/unit/misc/test_misc.py +++ b/tests/unit/misc/test_misc.py @@ -1,5 +1,5 @@ from __future__ import annotations -import pytest + import pytest from muutils.misc import ( @@ -7,9 +7,9 @@ freeze, list_join, list_split, - sanitize_name, sanitize_fname, sanitize_identifier, + sanitize_name, stable_hash, ) @@ -37,8 +37,6 @@ def test_sanitize_fname(): assert sanitize_fname(None) == "_None_", "None input should return '_None_'" - - def test_sanitize_name(): assert sanitize_name("Hello World") == "HelloWorld" assert sanitize_name("Hello_World", additional_allowed_chars="_") == "Hello_World" @@ -50,6 +48,7 @@ def test_sanitize_name(): assert sanitize_name("123abc") == "123abc" assert sanitize_name("123abc", leading_digit_prefix="_") == "_123abc" + def test_sanitize_fname_2(): assert sanitize_fname("file name.txt") == "filename.txt" assert sanitize_fname("file_name.txt") == "file_name.txt" @@ -62,6 +61,7 @@ def test_sanitize_fname_2(): assert sanitize_fname("123file.txt") == "123file.txt" assert sanitize_fname("123file.txt", leading_digit_prefix="_") == "_123file.txt" + def test_sanitize_identifier(): assert sanitize_identifier("variable_name") == "variable_name" assert sanitize_identifier("VariableName") == "VariableName" @@ -72,6 +72,7 @@ def test_sanitize_identifier(): with pytest.raises(ValueError): sanitize_identifier(None, when_none=None) + def test_dict_to_filename(): data = {"key1": "value1", "key2": "value2"} assert ( diff --git a/tests/unit/misc/test_numerical_conversions.py b/tests/unit/misc/test_numerical_conversions.py index f2175f2e..15c39ba5 100644 --- a/tests/unit/misc/test_numerical_conversions.py +++ b/tests/unit/misc/test_numerical_conversions.py @@ -1,4 +1,5 @@ from __future__ import annotations + import random from math import isclose, isinf, isnan diff --git a/tests/unit/nbutils/test_conversion.py b/tests/unit/nbutils/test_conversion.py index 8301f2da..f984432f 100644 --- a/tests/unit/nbutils/test_conversion.py +++ b/tests/unit/nbutils/test_conversion.py @@ -1,4 +1,5 @@ from __future__ import annotations + import itertools import os diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index 81e3f1ca..3b71847e 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -1,5 +1,5 @@ -from pathlib import Path import sys +from pathlib import Path from muutils.mlutils import get_checkpoint_paths_for_run, register_method diff --git a/tests/unit/test_statcounter.py b/tests/unit/test_statcounter.py index d373cf95..27f6ebae 100644 --- a/tests/unit/test_statcounter.py +++ b/tests/unit/test_statcounter.py @@ -1,4 +1,5 @@ from __future__ import annotations + import numpy as np from muutils.statcounter import StatCounter diff --git a/tests/unit/test_tensor_utils.py b/tests/unit/test_tensor_utils.py index 17c4ed15..f61c4f1e 100644 --- a/tests/unit/test_tensor_utils.py +++ b/tests/unit/test_tensor_utils.py @@ -1,4 +1,5 @@ from __future__ import annotations + import jaxtyping import numpy as np import pytest From 66c99b7289d85106c20ed62dd6fd2955a2307807 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 20:35:23 -0700 Subject: [PATCH 013/158] fix mro --- muutils/json_serialize/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 7787725b..8de1d7d9 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -43,7 +43,7 @@ def __init_subclass__(cls, *args, **kwargs): def __class_getitem__(cls, params): if isinstance(params, type): typing.GenericAlias(tuple, (params, Ellipsis)) - elif any("typing.UnionType" in str(t) for t in params.mro()): + elif any("typing.UnionType" in str(t) for t in params.__class__.__mro__): # TODO: unsure about this # check via mro return typing.GenericAlias(tuple, (params, Ellipsis)) From 5584855a5909445cbc9cbcc312af439bd7770c67 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 20:46:27 -0700 Subject: [PATCH 014/158] fix monotuple --- muutils/json_serialize/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 8de1d7d9..45174909 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -41,12 +41,12 @@ def __init_subclass__(cls, *args, **kwargs): # idk why mypy thinks there is no such function in typing @typing._tp_cache # type: ignore def __class_getitem__(cls, params): - if isinstance(params, type): - typing.GenericAlias(tuple, (params, Ellipsis)) - elif any("typing.UnionType" in str(t) for t in params.__class__.__mro__): + if any("typing.UnionType" in str(t) for t in params.__class__.__mro__): # TODO: unsure about this # check via mro return typing.GenericAlias(tuple, (params, Ellipsis)) + elif isinstance(params, type): + typing.GenericAlias(tuple, (params, Ellipsis)) # test if has len and is iterable elif isinstance(params, Iterable): if len(params) == 0: From 4be72749d44d99874a9fd3e7299f7a17701d405e Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 21:05:50 -0700 Subject: [PATCH 015/158] added base64 array mode (similar to existing hex mode) --- muutils/json_serialize/array.py | 26 ++++++++++++++++++++++--- tests/unit/json_serialize/test_array.py | 5 +++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 866f6d43..15f0646e 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import typing import warnings from typing import Any, Iterable, Literal, Optional, Sequence @@ -10,7 +11,7 @@ # pylint: disable=unused-argument -ArrayMode = Literal["list", "array_list_meta", "array_hex_meta", "external", "zero_dim"] +ArrayMode = Literal["list", "array_list_meta", "array_hex_meta", "array_b64_meta", "external", "zero_dim"] def array_n_elements(arr) -> int: # type: ignore[name-defined] @@ -48,14 +49,15 @@ def serialize_array( - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`) - `array_list_meta`: serialize dict with metadata, actual list under the key `data` - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data` + - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data` - for `array_list_meta` and `array_hex_meta`, the output will look like + for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is: ``` { "__format__": , "shape": arr.shape, "dtype": str(arr.dtype), - "data": , + "data": , } ``` @@ -100,6 +102,12 @@ def serialize_array( "data": arr_np.tobytes().hex(), **arr_metadata(arr_np), } + elif array_mode == "array_b64_meta": + return { + "__format__": f"{arr_type}:array_b64_meta", + "data": base64.b64encode(arr_np.tobytes()).decode(), + **arr_metadata(arr_np), + } else: raise KeyError(f"invalid array_mode: {array_mode}") @@ -119,6 +127,10 @@ def infer_array_mode(arr: JSONitem) -> ArrayMode: if not isinstance(arr["data"], str): raise ValueError(f"invalid hex format: {type(arr['data']) = }\t{arr}") return "array_hex_meta" + elif fmt.endswith(":array_b64_meta"): + if not isinstance(arr["data"], str): + raise ValueError(f"invalid b64 format: {type(arr['data']) = }\t{arr}") + return "array_b64_meta" elif fmt.endswith(":external"): return "external" elif fmt.endswith(":zero_dim"): @@ -164,6 +176,14 @@ def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any: data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"]) return data.reshape(arr["shape"]) + + elif array_mode == "array_b64_meta": + assert isinstance( + arr, typing.Mapping + ), f"invalid list format: {type(arr) = }\n{arr = }" + + data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"]) + return data.reshape(arr["shape"]) elif array_mode == "list": assert isinstance( diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index dc3d767a..d921fa1f 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -39,6 +39,7 @@ def test_arr_metadata(self): ("list", list), ("array_list_meta", dict), ("array_hex_meta", dict), + ("array_b64_meta", dict), ], ) def test_serialize_array(self, array_mode: ArrayMode, expected_type: type): @@ -55,7 +56,7 @@ def test_load_array(self): assert np.array_equal(loaded_array, self.array_3d) def test_serialize_load_integration(self): - for array_mode in ["list", "array_list_meta", "array_hex_meta"]: + for array_mode in ["list", "array_list_meta", "array_hex_meta", "array_b64_meta"]: for array in [self.array_1d, self.array_2d, self.array_3d]: serialized_array = serialize_array( self.jser, array, "test_path", array_mode=array_mode @@ -64,7 +65,7 @@ def test_serialize_load_integration(self): assert np.array_equal(loaded_array, array) def test_serialize_load_zero_dim(self): - for array_mode in ["list", "array_list_meta", "array_hex_meta"]: + for array_mode in ["list", "array_list_meta", "array_hex_meta", "array_b64_meta"]: serialized_array = serialize_array( self.jser, self.array_zero_dim, "test_path", array_mode=array_mode ) From 4100cf489e4c26c48c7d7ba556788f8ba0ebd486 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 21:12:12 -0700 Subject: [PATCH 016/158] remove py3.12 for now --- .github/workflows/checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 77d9f635..37ca9bc9 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -40,8 +40,8 @@ jobs: torch: "2.3.1" - python: "3.11" torch: "2.3.1" - - python: "3.12" - torch: "2.3.1" + # - python: "3.12" + # torch: "2.3.1" steps: - name: Checkout code uses: actions/checkout@v2 From 7c301c56c2b86dc9d5641fc0e43c74556c21ef60 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 21:12:53 -0700 Subject: [PATCH 017/158] run format --- muutils/json_serialize/array.py | 11 +++++++++-- tests/unit/json_serialize/test_array.py | 14 ++++++++++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 15f0646e..4a810f0c 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -11,7 +11,14 @@ # pylint: disable=unused-argument -ArrayMode = Literal["list", "array_list_meta", "array_hex_meta", "array_b64_meta", "external", "zero_dim"] +ArrayMode = Literal[ + "list", + "array_list_meta", + "array_hex_meta", + "array_b64_meta", + "external", + "zero_dim", +] def array_n_elements(arr) -> int: # type: ignore[name-defined] @@ -176,7 +183,7 @@ def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any: data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"]) return data.reshape(arr["shape"]) - + elif array_mode == "array_b64_meta": assert isinstance( arr, typing.Mapping diff --git a/tests/unit/json_serialize/test_array.py b/tests/unit/json_serialize/test_array.py index d921fa1f..91e9b445 100644 --- a/tests/unit/json_serialize/test_array.py +++ b/tests/unit/json_serialize/test_array.py @@ -56,7 +56,12 @@ def test_load_array(self): assert np.array_equal(loaded_array, self.array_3d) def test_serialize_load_integration(self): - for array_mode in ["list", "array_list_meta", "array_hex_meta", "array_b64_meta"]: + for array_mode in [ + "list", + "array_list_meta", + "array_hex_meta", + "array_b64_meta", + ]: for array in [self.array_1d, self.array_2d, self.array_3d]: serialized_array = serialize_array( self.jser, array, "test_path", array_mode=array_mode @@ -65,7 +70,12 @@ def test_serialize_load_integration(self): assert np.array_equal(loaded_array, array) def test_serialize_load_zero_dim(self): - for array_mode in ["list", "array_list_meta", "array_hex_meta", "array_b64_meta"]: + for array_mode in [ + "list", + "array_list_meta", + "array_hex_meta", + "array_b64_meta", + ]: serialized_array = serialize_array( self.jser, self.array_zero_dim, "test_path", array_mode=array_mode ) From 07cac42488df242ce57c41fcaf8c9e051529a0c8 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 23:37:56 -0700 Subject: [PATCH 018/158] hacky fix for failing to get type hints in python 3.8 - cant get around `typing.get_type_hints` failing on `dict[str, int]` style types, event with `typing-extensions` - added some warnings and error checks if the python version is old - got confused by my own code but commented what I could - added some TODO comments - went down some rabbit holes, learned a lot about python :) - that was an adventure --- .../json_serialize/serializable_dataclass.py | 63 +++- muutils/json_serialize/util.py | 6 +- test.ipynb | 270 ++++++++++++++++++ .../test_serializable_dataclass.py | 32 ++- 4 files changed, 361 insertions(+), 10 deletions(-) create mode 100644 test.ipynb diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 2d60150b..f68aa948 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -9,6 +9,7 @@ import warnings from typing import Any, Callable, Optional, Type, TypeVar, Union + # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access @@ -52,6 +53,7 @@ def __init__( serialization_fn: Optional[Callable[[Any], Any]] = None, loading_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, + # TODO: add field for custom type assertion ): # TODO: should we do this check, or assume the user knows what they are doing? if init and not serialize: @@ -396,6 +398,7 @@ def __deepcopy__(self, memo: dict) -> "SerializableDataclass": # Step 3: Create a custom serializable_dataclass decorator +# TODO: add a kwarg for always asserting type for all fields def serializable_dataclass( # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it _cls=None, # type: ignore @@ -504,36 +507,88 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: data, typing.Mapping ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" - cls_type_hints: dict[str, Any] = typing.get_type_hints(cls) + # get the type hints for the class + cls_type_hints: dict[str, Any] + try: + cls_type_hints = typing.get_type_hints(cls) + except TypeError as e: + if sys.version_info < (3, 9): + warnings.warn( + f"Cannot get type hints for {cls.__name__}. Python version is {sys.version_info = }. You can:\n" + + " - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x)\n" + + " - use python 3.9.x or higher\n" + + " - add explicit loading functions to the fields\n" + + f" {dataclasses.fields(cls) = }" + ) + cls_type_hints = dict() + else: + raise TypeError( + f"Cannot get type hints for {cls.__name__}. Python version is {sys.version_info = }\n" + + f" {dataclasses.fields(cls) = }\n" + + f" {e = }" + ) from e + + # initialize dict for keeping what we will pass to the constructor ctor_kwargs: dict[str, Any] = dict() + + # iterate over the fields of the class for field in dataclasses.fields(cls): + # check if the field is a SerializableField assert isinstance( field, SerializableField - ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)} this state should be inaccessible, please report this bug!" + ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" + # check if the field is in the data and if it should be initialized if (field.name in data) and field.init: - value = data[field.name] + # get the value, we will be processing it + value: Any = data[field.name] + # get the type hint for the field field_type_hint: Any = cls_type_hints.get(field.name, None) + if field.loading_fn: + # if it has a loading function, use that value = field.loading_fn(data) elif ( field_type_hint is not None and hasattr(field_type_hint, "load") and callable(field_type_hint.load) ): + # if no loading function but has a type hint with a load method, use that if isinstance(value, dict): + # TODO: should this be passing the whole data dict? value = field_type_hint.load(value) else: raise ValueError( f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" ) + else: + # assume no loading needs to happen, keep `value` as-is + pass + # validate the type if field.assert_type: if field.name in ctor_kwargs: - assert isinstance(ctor_kwargs[field.name], field_type_hint) + if field_type_hint is not None: + # TODO: recursive type hint checking like pydantic? + assert isinstance(ctor_kwargs[field.name], field_type_hint) + else: + raise ValueError( + f"Cannot get type hints for {cls.__name__}, and so cannot validate. Python version is {sys.version_info = }. You can:\n" + + f" - disable `assert_type`. Currently: {field.assert_type = }\n" + + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {field.type = }\n" + + " - use python 3.9.x or higher\n" + + " - coming in a future release, specify custom type validation functions\n" + ) + else: + # TODO: raise an exception here? Can't validate if no type hint given + warnings.warn( + f"Field '{field.name}' on class {cls} has no type hint, but {field.assert_type = }\n{field = }\n{cls_type_hints = }\n{data = }" + ) + # store the value in the constructor kwargs ctor_kwargs[field.name] = value + return cls(**ctor_kwargs) # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 45174909..287675a2 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -23,7 +23,7 @@ Hashableitem = Union[bool, int, float, str, tuple] # or if python version <3.9 -if typing.TYPE_CHECKING or sys.version_info[1] < 9: +if typing.TYPE_CHECKING or sys.version_info < (3, 9): MonoTuple = typing.Sequence else: @@ -41,9 +41,7 @@ def __init_subclass__(cls, *args, **kwargs): # idk why mypy thinks there is no such function in typing @typing._tp_cache # type: ignore def __class_getitem__(cls, params): - if any("typing.UnionType" in str(t) for t in params.__class__.__mro__): - # TODO: unsure about this - # check via mro + if getattr(params, "__origin__", None) == typing.Union: return typing.GenericAlias(tuple, (params, Ellipsis)) elif isinstance(params, type): typing.GenericAlias(tuple, (params, Ellipsis)) diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 00000000..2372f691 --- /dev/null +++ b/test.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "'type' object is not subscriptable", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32mf:\\projects\\tools\\muutils\\test.ipynb Cell 1\u001b[0m line \u001b[0;36m1\n\u001b[1;32m----> 1\u001b[0m x: \u001b[39mdict\u001b[39;49m[\u001b[39mstr\u001b[39;49m,\u001b[39mint\u001b[39;49m]\n", + "\u001b[1;31mTypeError\u001b[0m: 'type' object is not subscriptable" + ] + } + ], + "source": [ + "x: dict[str,int]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "x: dict[str,int]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def f(x: dict[str,int]) -> list[str]:\n", + " return list(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "'type' object is not subscriptable", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32mf:\\projects\\tools\\muutils\\test.ipynb Cell 4\u001b[0m line \u001b[0;36m2\n\u001b[0;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtyping\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m typing\u001b[39m.\u001b[39;49mget_type_hints(f)\n", + "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:1264\u001b[0m, in \u001b[0;36mget_type_hints\u001b[1;34m(obj, globalns, localns)\u001b[0m\n\u001b[0;32m 1262\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(value, \u001b[39mstr\u001b[39m):\n\u001b[0;32m 1263\u001b[0m value \u001b[39m=\u001b[39m ForwardRef(value)\n\u001b[1;32m-> 1264\u001b[0m value \u001b[39m=\u001b[39m _eval_type(value, globalns, localns)\n\u001b[0;32m 1265\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m defaults \u001b[39mand\u001b[39;00m defaults[name] \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 1266\u001b[0m value \u001b[39m=\u001b[39m Optional[value]\n", + "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:270\u001b[0m, in \u001b[0;36m_eval_type\u001b[1;34m(t, globalns, localns)\u001b[0m\n\u001b[0;32m 266\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Evaluate all forward references in the given type t.\u001b[39;00m\n\u001b[0;32m 267\u001b[0m \u001b[39mFor use of globalns and localns see the docstring for get_type_hints().\u001b[39;00m\n\u001b[0;32m 268\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 269\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(t, ForwardRef):\n\u001b[1;32m--> 270\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39;49m_evaluate(globalns, localns)\n\u001b[0;32m 271\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(t, _GenericAlias):\n\u001b[0;32m 272\u001b[0m ev_args \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39m(_eval_type(a, globalns, localns) \u001b[39mfor\u001b[39;00m a \u001b[39min\u001b[39;00m t\u001b[39m.\u001b[39m__args__)\n", + "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:518\u001b[0m, in \u001b[0;36mForwardRef._evaluate\u001b[1;34m(self, globalns, localns)\u001b[0m\n\u001b[0;32m 515\u001b[0m \u001b[39melif\u001b[39;00m localns \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 516\u001b[0m localns \u001b[39m=\u001b[39m globalns\n\u001b[0;32m 517\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_value__ \u001b[39m=\u001b[39m _type_check(\n\u001b[1;32m--> 518\u001b[0m \u001b[39meval\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__forward_code__, globalns, localns),\n\u001b[0;32m 519\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mForward references must evaluate to types.\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m 520\u001b[0m is_argument\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_is_argument__)\n\u001b[0;32m 521\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_evaluated__ \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[0;32m 522\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_value__\n", + "File \u001b[1;32m:1\u001b[0m\n", + "\u001b[1;31mTypeError\u001b[0m: 'type' object is not subscriptable" + ] + } + ], + "source": [ + "import typing\n", + "typing.get_type_hints(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "'type' object is not subscriptable", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32mf:\\projects\\tools\\muutils\\test.ipynb Cell 5\u001b[0m line \u001b[0;36m2\n\u001b[0;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtyping_extensions\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m typing_extensions\u001b[39m.\u001b[39;49mget_type_hints(f)\n", + "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\site-packages\\typing_extensions.py:1272\u001b[0m, in \u001b[0;36mget_type_hints\u001b[1;34m(obj, globalns, localns, include_extras)\u001b[0m\n\u001b[0;32m 1241\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mget_type_hints\u001b[39m(obj, globalns\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, localns\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, include_extras\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m):\n\u001b[0;32m 1242\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Return type hints for an object.\u001b[39;00m\n\u001b[0;32m 1243\u001b[0m \n\u001b[0;32m 1244\u001b[0m \u001b[39m This is often the same as obj.__annotations__, but it handles\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 1270\u001b[0m \u001b[39m locals, respectively.\u001b[39;00m\n\u001b[0;32m 1271\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m-> 1272\u001b[0m hint \u001b[39m=\u001b[39m typing\u001b[39m.\u001b[39;49mget_type_hints(obj, globalns\u001b[39m=\u001b[39;49mglobalns, localns\u001b[39m=\u001b[39;49mlocalns)\n\u001b[0;32m 1273\u001b[0m \u001b[39mif\u001b[39;00m include_extras:\n\u001b[0;32m 1274\u001b[0m \u001b[39mreturn\u001b[39;00m hint\n", + "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:1264\u001b[0m, in \u001b[0;36mget_type_hints\u001b[1;34m(obj, globalns, localns)\u001b[0m\n\u001b[0;32m 1262\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(value, \u001b[39mstr\u001b[39m):\n\u001b[0;32m 1263\u001b[0m value \u001b[39m=\u001b[39m ForwardRef(value)\n\u001b[1;32m-> 1264\u001b[0m value \u001b[39m=\u001b[39m _eval_type(value, globalns, localns)\n\u001b[0;32m 1265\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m defaults \u001b[39mand\u001b[39;00m defaults[name] \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 1266\u001b[0m value \u001b[39m=\u001b[39m Optional[value]\n", + "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:270\u001b[0m, in \u001b[0;36m_eval_type\u001b[1;34m(t, globalns, localns)\u001b[0m\n\u001b[0;32m 266\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Evaluate all forward references in the given type t.\u001b[39;00m\n\u001b[0;32m 267\u001b[0m \u001b[39mFor use of globalns and localns see the docstring for get_type_hints().\u001b[39;00m\n\u001b[0;32m 268\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 269\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(t, ForwardRef):\n\u001b[1;32m--> 270\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39;49m_evaluate(globalns, localns)\n\u001b[0;32m 271\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(t, _GenericAlias):\n\u001b[0;32m 272\u001b[0m ev_args \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39m(_eval_type(a, globalns, localns) \u001b[39mfor\u001b[39;00m a \u001b[39min\u001b[39;00m t\u001b[39m.\u001b[39m__args__)\n", + "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:518\u001b[0m, in \u001b[0;36mForwardRef._evaluate\u001b[1;34m(self, globalns, localns)\u001b[0m\n\u001b[0;32m 515\u001b[0m \u001b[39melif\u001b[39;00m localns \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 516\u001b[0m localns \u001b[39m=\u001b[39m globalns\n\u001b[0;32m 517\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_value__ \u001b[39m=\u001b[39m _type_check(\n\u001b[1;32m--> 518\u001b[0m \u001b[39meval\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__forward_code__, globalns, localns),\n\u001b[0;32m 519\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mForward references must evaluate to types.\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m 520\u001b[0m is_argument\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_is_argument__)\n\u001b[0;32m 521\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_evaluated__ \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[0;32m 522\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_value__\n", + "File \u001b[1;32m:1\u001b[0m\n", + "\u001b[1;31mTypeError\u001b[0m: 'type' object is not subscriptable" + ] + } + ], + "source": [ + "import typing_extensions\n", + "typing_extensions.get_type_hints(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict, List\n", + "\n", + "def f_t(x: Dict[str,int]) -> List[str]:\n", + " return list(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': typing.Dict[str, int], 'return': typing.List[str]}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "typing.get_type_hints(f_t)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Field(name='x',type='typing.Dict[str, int]',default=,default_factory=,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),_field_type=_FIELD),\n", + " Field(name='y',type='list[str]',default=,default_factory=,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),_field_type=_FIELD),\n", + " Field(name='z',type='str',default=,default_factory=,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),_field_type=_FIELD))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import dataclasses\n", + "\n", + "@dataclasses.dataclass\n", + "class Test:\n", + "\tx: typing.Dict[str,int]\n", + "\ty: list[str]\n", + "\tz: str\n", + "\n", + "dataclasses.fields(Test)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pytest\n", + "import warnings\n", + "\n", + "with pytest.warns(UserWarning) as record:\n", + "\twarnings.warn(\"test\", UserWarning)\n", + "\twarnings.warn(\"test2\", UserWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['_WARNING_DETAILS',\n", + " '__class__',\n", + " '__delattr__',\n", + " '__dict__',\n", + " '__dir__',\n", + " '__doc__',\n", + " '__eq__',\n", + " '__format__',\n", + " '__ge__',\n", + " '__getattribute__',\n", + " '__gt__',\n", + " '__hash__',\n", + " '__init__',\n", + " '__init_subclass__',\n", + " '__le__',\n", + " '__lt__',\n", + " '__module__',\n", + " '__ne__',\n", + " '__new__',\n", + " '__reduce__',\n", + " '__reduce_ex__',\n", + " '__repr__',\n", + " '__setattr__',\n", + " '__sizeof__',\n", + " '__str__',\n", + " '__subclasshook__',\n", + " '__weakref__',\n", + " '_category_name',\n", + " 'category',\n", + " 'file',\n", + " 'filename',\n", + " 'line',\n", + " 'lineno',\n", + " 'message',\n", + " 'source']" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dir(record[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "UserWarning('test')" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "record[0].message" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 3f754f60..26dfe931 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Any +import sys import pytest @@ -13,6 +14,8 @@ # pylint: disable=missing-class-docstring, unused-variable +BELOW_PY_3_9: bool = sys.version_info < (3, 9) + @serializable_dataclass class BasicAutofields(SerializableDataclass): a: str @@ -107,6 +110,16 @@ def test_simple_fields_serialization(simple_fields_instance): def test_simple_fields_loading(simple_fields_instance): serialized = simple_fields_instance.serialize() + + if BELOW_PY_3_9: + with pytest.warns(UserWarning) as record: + loaded = SimpleFields.load(serialized) + print([x.message for x in record]) + assert len(record) == 4 + else: + loaded = SimpleFields.load(serialized) + + loaded = SimpleFields.load(serialized) assert loaded == simple_fields_instance assert loaded.diff(simple_fields_instance) == {} @@ -262,7 +275,14 @@ def full_name(self) -> str: } assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}" - loaded = FullPerson.load(serialized) + if BELOW_PY_3_9: + with pytest.warns(UserWarning) as record: + loaded = FullPerson.load(serialized) + print([x.message for x in record]) + assert len(record) == 4 + else: + loaded = FullPerson.load(serialized) + assert loaded == person @@ -327,7 +347,15 @@ def test_nested_with_container(): } assert serialized == expected_ser - loaded = Nested_with_Container.load(serialized) + + if BELOW_PY_3_9: + with pytest.warns(UserWarning) as record: + loaded = Nested_with_Container.load(serialized) + print([x.message for x in record]) + assert len(record) == 12 + else: + loaded = Nested_with_Container.load(serialized) + assert loaded == instance From ce12d94cc523451b18c285e5b41f013b37ca4fea Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 23:40:39 -0700 Subject: [PATCH 019/158] deleted test notebook it was junk, never should have been in the repo, but it's there in the commit history to remind me of what exactly goes wrong just in case --- test.ipynb | 270 ----------------------------------------------------- 1 file changed, 270 deletions(-) delete mode 100644 test.ipynb diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 2372f691..00000000 --- a/test.ipynb +++ /dev/null @@ -1,270 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "'type' object is not subscriptable", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mf:\\projects\\tools\\muutils\\test.ipynb Cell 1\u001b[0m line \u001b[0;36m1\n\u001b[1;32m----> 1\u001b[0m x: \u001b[39mdict\u001b[39;49m[\u001b[39mstr\u001b[39;49m,\u001b[39mint\u001b[39;49m]\n", - "\u001b[1;31mTypeError\u001b[0m: 'type' object is not subscriptable" - ] - } - ], - "source": [ - "x: dict[str,int]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import annotations\n", - "x: dict[str,int]" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def f(x: dict[str,int]) -> list[str]:\n", - " return list(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "'type' object is not subscriptable", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mf:\\projects\\tools\\muutils\\test.ipynb Cell 4\u001b[0m line \u001b[0;36m2\n\u001b[0;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtyping\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m typing\u001b[39m.\u001b[39;49mget_type_hints(f)\n", - "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:1264\u001b[0m, in \u001b[0;36mget_type_hints\u001b[1;34m(obj, globalns, localns)\u001b[0m\n\u001b[0;32m 1262\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(value, \u001b[39mstr\u001b[39m):\n\u001b[0;32m 1263\u001b[0m value \u001b[39m=\u001b[39m ForwardRef(value)\n\u001b[1;32m-> 1264\u001b[0m value \u001b[39m=\u001b[39m _eval_type(value, globalns, localns)\n\u001b[0;32m 1265\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m defaults \u001b[39mand\u001b[39;00m defaults[name] \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 1266\u001b[0m value \u001b[39m=\u001b[39m Optional[value]\n", - "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:270\u001b[0m, in \u001b[0;36m_eval_type\u001b[1;34m(t, globalns, localns)\u001b[0m\n\u001b[0;32m 266\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Evaluate all forward references in the given type t.\u001b[39;00m\n\u001b[0;32m 267\u001b[0m \u001b[39mFor use of globalns and localns see the docstring for get_type_hints().\u001b[39;00m\n\u001b[0;32m 268\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 269\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(t, ForwardRef):\n\u001b[1;32m--> 270\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39;49m_evaluate(globalns, localns)\n\u001b[0;32m 271\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(t, _GenericAlias):\n\u001b[0;32m 272\u001b[0m ev_args \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39m(_eval_type(a, globalns, localns) \u001b[39mfor\u001b[39;00m a \u001b[39min\u001b[39;00m t\u001b[39m.\u001b[39m__args__)\n", - "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:518\u001b[0m, in \u001b[0;36mForwardRef._evaluate\u001b[1;34m(self, globalns, localns)\u001b[0m\n\u001b[0;32m 515\u001b[0m \u001b[39melif\u001b[39;00m localns \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 516\u001b[0m localns \u001b[39m=\u001b[39m globalns\n\u001b[0;32m 517\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_value__ \u001b[39m=\u001b[39m _type_check(\n\u001b[1;32m--> 518\u001b[0m \u001b[39meval\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__forward_code__, globalns, localns),\n\u001b[0;32m 519\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mForward references must evaluate to types.\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m 520\u001b[0m is_argument\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_is_argument__)\n\u001b[0;32m 521\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_evaluated__ \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[0;32m 522\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_value__\n", - "File \u001b[1;32m:1\u001b[0m\n", - "\u001b[1;31mTypeError\u001b[0m: 'type' object is not subscriptable" - ] - } - ], - "source": [ - "import typing\n", - "typing.get_type_hints(f)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "'type' object is not subscriptable", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mf:\\projects\\tools\\muutils\\test.ipynb Cell 5\u001b[0m line \u001b[0;36m2\n\u001b[0;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtyping_extensions\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m typing_extensions\u001b[39m.\u001b[39;49mget_type_hints(f)\n", - "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\site-packages\\typing_extensions.py:1272\u001b[0m, in \u001b[0;36mget_type_hints\u001b[1;34m(obj, globalns, localns, include_extras)\u001b[0m\n\u001b[0;32m 1241\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mget_type_hints\u001b[39m(obj, globalns\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, localns\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, include_extras\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m):\n\u001b[0;32m 1242\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Return type hints for an object.\u001b[39;00m\n\u001b[0;32m 1243\u001b[0m \n\u001b[0;32m 1244\u001b[0m \u001b[39m This is often the same as obj.__annotations__, but it handles\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 1270\u001b[0m \u001b[39m locals, respectively.\u001b[39;00m\n\u001b[0;32m 1271\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m-> 1272\u001b[0m hint \u001b[39m=\u001b[39m typing\u001b[39m.\u001b[39;49mget_type_hints(obj, globalns\u001b[39m=\u001b[39;49mglobalns, localns\u001b[39m=\u001b[39;49mlocalns)\n\u001b[0;32m 1273\u001b[0m \u001b[39mif\u001b[39;00m include_extras:\n\u001b[0;32m 1274\u001b[0m \u001b[39mreturn\u001b[39;00m hint\n", - "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:1264\u001b[0m, in \u001b[0;36mget_type_hints\u001b[1;34m(obj, globalns, localns)\u001b[0m\n\u001b[0;32m 1262\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(value, \u001b[39mstr\u001b[39m):\n\u001b[0;32m 1263\u001b[0m value \u001b[39m=\u001b[39m ForwardRef(value)\n\u001b[1;32m-> 1264\u001b[0m value \u001b[39m=\u001b[39m _eval_type(value, globalns, localns)\n\u001b[0;32m 1265\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m defaults \u001b[39mand\u001b[39;00m defaults[name] \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 1266\u001b[0m value \u001b[39m=\u001b[39m Optional[value]\n", - "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:270\u001b[0m, in \u001b[0;36m_eval_type\u001b[1;34m(t, globalns, localns)\u001b[0m\n\u001b[0;32m 266\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Evaluate all forward references in the given type t.\u001b[39;00m\n\u001b[0;32m 267\u001b[0m \u001b[39mFor use of globalns and localns see the docstring for get_type_hints().\u001b[39;00m\n\u001b[0;32m 268\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 269\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(t, ForwardRef):\n\u001b[1;32m--> 270\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39;49m_evaluate(globalns, localns)\n\u001b[0;32m 271\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(t, _GenericAlias):\n\u001b[0;32m 272\u001b[0m ev_args \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39m(_eval_type(a, globalns, localns) \u001b[39mfor\u001b[39;00m a \u001b[39min\u001b[39;00m t\u001b[39m.\u001b[39m__args__)\n", - "File \u001b[1;32mc:\\Python\\Python3_8\\lib\\typing.py:518\u001b[0m, in \u001b[0;36mForwardRef._evaluate\u001b[1;34m(self, globalns, localns)\u001b[0m\n\u001b[0;32m 515\u001b[0m \u001b[39melif\u001b[39;00m localns \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 516\u001b[0m localns \u001b[39m=\u001b[39m globalns\n\u001b[0;32m 517\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_value__ \u001b[39m=\u001b[39m _type_check(\n\u001b[1;32m--> 518\u001b[0m \u001b[39meval\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__forward_code__, globalns, localns),\n\u001b[0;32m 519\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mForward references must evaluate to types.\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m 520\u001b[0m is_argument\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_is_argument__)\n\u001b[0;32m 521\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_evaluated__ \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[0;32m 522\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__forward_value__\n", - "File \u001b[1;32m:1\u001b[0m\n", - "\u001b[1;31mTypeError\u001b[0m: 'type' object is not subscriptable" - ] - } - ], - "source": [ - "import typing_extensions\n", - "typing_extensions.get_type_hints(f)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Dict, List\n", - "\n", - "def f_t(x: Dict[str,int]) -> List[str]:\n", - " return list(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'x': typing.Dict[str, int], 'return': typing.List[str]}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "typing.get_type_hints(f_t)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(Field(name='x',type='typing.Dict[str, int]',default=,default_factory=,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),_field_type=_FIELD),\n", - " Field(name='y',type='list[str]',default=,default_factory=,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),_field_type=_FIELD),\n", - " Field(name='z',type='str',default=,default_factory=,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),_field_type=_FIELD))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import dataclasses\n", - "\n", - "@dataclasses.dataclass\n", - "class Test:\n", - "\tx: typing.Dict[str,int]\n", - "\ty: list[str]\n", - "\tz: str\n", - "\n", - "dataclasses.fields(Test)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pytest\n", - "import warnings\n", - "\n", - "with pytest.warns(UserWarning) as record:\n", - "\twarnings.warn(\"test\", UserWarning)\n", - "\twarnings.warn(\"test2\", UserWarning)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['_WARNING_DETAILS',\n", - " '__class__',\n", - " '__delattr__',\n", - " '__dict__',\n", - " '__dir__',\n", - " '__doc__',\n", - " '__eq__',\n", - " '__format__',\n", - " '__ge__',\n", - " '__getattribute__',\n", - " '__gt__',\n", - " '__hash__',\n", - " '__init__',\n", - " '__init_subclass__',\n", - " '__le__',\n", - " '__lt__',\n", - " '__module__',\n", - " '__ne__',\n", - " '__new__',\n", - " '__reduce__',\n", - " '__reduce_ex__',\n", - " '__repr__',\n", - " '__setattr__',\n", - " '__sizeof__',\n", - " '__str__',\n", - " '__subclasshook__',\n", - " '__weakref__',\n", - " '_category_name',\n", - " 'category',\n", - " 'file',\n", - " 'filename',\n", - " 'line',\n", - " 'lineno',\n", - " 'message',\n", - " 'source']" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(record[0])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "UserWarning('test')" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "record[0].message" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 80f8aafd3c721090ef4607df74ed0aa0d3ec02dc Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 23:54:39 -0700 Subject: [PATCH 020/158] fix more warnings --- .../test_sdc_defaults.py | 29 ++++++++-- .../test_sdc_properties_nested.py | 21 ++++++- .../test_serializable_dataclass.py | 56 +++++++++---------- 3 files changed, 70 insertions(+), 36 deletions(-) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 47d373e0..1f878efc 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -1,4 +1,8 @@ -from typing import Dict +from __future__ import annotations +import sys +from typing import Dict, Any + +import pytest from muutils.json_serialize import ( JsonSerializer, @@ -9,6 +13,23 @@ # pylint: disable=missing-class-docstring +BELOW_PY_3_9: bool = sys.version_info < (3, 9) + + +def _loading_test_wrapper(cls, data, assert_record_len: int|None = None) -> Any: + """wrapper for testing the load function, which accounts for version differences""" + if BELOW_PY_3_9: + with pytest.warns(UserWarning) as record: + loaded = cls.load(data) + print([x.message for x in record]) + if assert_record_len is not None: + assert len(record) == assert_record_len + return loaded + else: + loaded = cls.load(data) + return loaded + + @serializable_dataclass class Config(SerializableDataclass): @@ -26,7 +47,7 @@ def test_sdc_empty(): "batch_size": 64, "__format__": "Config(SerializableDataclass)", } - recovered = Config.load(serialized) + recovered = _loading_test_wrapper(Config, serialized) assert recovered == instance @@ -40,7 +61,7 @@ def test_sdc_strip_format_jser(): "batch_size": 64, "__write_format__": "Config(SerializableDataclass)", } - recovered = Config.load(serialized) + recovered = _loading_test_wrapper(Config, serialized) assert recovered == instance @@ -63,5 +84,5 @@ class ComplicatedConfig(SerializableDataclass): def test_sdc_empty_complicated(): instance = ComplicatedConfig() serialized = instance.serialize() - recovered = ComplicatedConfig.load(serialized) + recovered = _loading_test_wrapper(ComplicatedConfig, serialized) assert recovered == instance diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 8600c3c8..d652cd21 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -10,6 +10,23 @@ print(f"{SUPPORS_KW_ONLY = }") +BELOW_PY_3_9: bool = sys.version_info < (3, 9) + + +def _loading_test_wrapper(cls, data, assert_record_len: int|None = None) -> Any: + """wrapper for testing the load function, which accounts for version differences""" + if BELOW_PY_3_9: + with pytest.warns(UserWarning) as record: + loaded = cls.load(data) + print([x.message for x in record]) + if assert_record_len is not None: + assert len(record) == assert_record_len + return loaded + else: + loaded = cls.load(data) + return loaded + + @serializable_dataclass class Person(SerializableDataclass): @@ -43,7 +60,7 @@ def test_serialize_person(): "__format__": "Person(SerializableDataclass)", } - recovered = Person.load(serialized) + recovered = _loading_test_wrapper(Person, serialized) assert recovered == instance @@ -66,6 +83,6 @@ def test_serialize_titled_person(): "full_title": "Dr. Jane Smith", } - recovered = TitledPerson.load(serialized) + recovered = _loading_test_wrapper(TitledPerson, serialized) assert recovered == instance diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 26dfe931..acbdc807 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -16,6 +16,22 @@ BELOW_PY_3_9: bool = sys.version_info < (3, 9) + + +def _loading_test_wrapper(cls, data, assert_record_len: int|None = None) -> Any: + """wrapper for testing the load function, which accounts for version differences""" + if BELOW_PY_3_9: + with pytest.warns(UserWarning) as record: + loaded = cls.load(data) + print([x.message for x in record]) + if assert_record_len is not None: + assert len(record) == assert_record_len + return loaded + else: + loaded = cls.load(data) + return loaded + + @serializable_dataclass class BasicAutofields(SerializableDataclass): a: str @@ -111,16 +127,8 @@ def test_simple_fields_serialization(simple_fields_instance): def test_simple_fields_loading(simple_fields_instance): serialized = simple_fields_instance.serialize() - if BELOW_PY_3_9: - with pytest.warns(UserWarning) as record: - loaded = SimpleFields.load(serialized) - print([x.message for x in record]) - assert len(record) == 4 - else: - loaded = SimpleFields.load(serialized) - + loaded = _loading_test_wrapper(SimpleFields, serialized, assert_record_len=4) - loaded = SimpleFields.load(serialized) assert loaded == simple_fields_instance assert loaded.diff(simple_fields_instance) == {} assert simple_fields_instance.diff(loaded) == {} @@ -138,7 +146,7 @@ def test_field_options_serialization(field_options_instance): def test_field_options_loading(field_options_instance): serialized = field_options_instance.serialize() - loaded = FieldOptions.load(serialized) + loaded = _loading_test_wrapper(FieldOptions, serialized, assert_record_len=3) assert loaded == field_options_instance @@ -154,7 +162,7 @@ def test_with_property_serialization(with_property_instance): def test_with_property_loading(with_property_instance): serialized = with_property_instance.serialize() - loaded = WithProperty.load(serialized) + loaded = _loading_test_wrapper(WithProperty, serialized, assert_record_len=2) assert loaded == with_property_instance @@ -200,7 +208,7 @@ def test_nested_serialization(person_instance): def test_nested_loading(person_instance): serialized = person_instance.serialize() - loaded = Person.load(serialized) + loaded = _loading_test_wrapper(Person, serialized, assert_record_len=6) assert loaded == person_instance assert loaded.address == person_instance.address @@ -223,7 +231,7 @@ def full_name(self) -> str: serialized_data = my_instance.serialize() print(serialized_data) - loaded_instance = MyClass.load(serialized_data) + loaded_instance = _loading_test_wrapper(MyClass, serialized_data, assert_record_len=3) print(loaded_instance) @@ -241,7 +249,7 @@ class SimpleClass(SerializableDataclass): "__format__": "SimpleClass(SerializableDataclass)", } - loaded = SimpleClass.load(serialized) + loaded = _loading_test_wrapper(SimpleClass, serialized, assert_record_len=2) assert loaded == simple @@ -275,13 +283,7 @@ def full_name(self) -> str: } assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}" - if BELOW_PY_3_9: - with pytest.warns(UserWarning) as record: - loaded = FullPerson.load(serialized) - print([x.message for x in record]) - assert len(record) == 4 - else: - loaded = FullPerson.load(serialized) + loaded = _loading_test_wrapper(FullPerson, serialized, assert_record_len=4) assert loaded == person @@ -300,7 +302,7 @@ class CustomSerialization(SerializableDataclass): "__format__": "CustomSerialization(SerializableDataclass)", } - loaded = CustomSerialization.load(serialized) + loaded = _loading_test_wrapper(CustomSerialization, serialized, assert_record_len=1) assert loaded == custom @@ -348,13 +350,7 @@ def test_nested_with_container(): assert serialized == expected_ser - if BELOW_PY_3_9: - with pytest.warns(UserWarning) as record: - loaded = Nested_with_Container.load(serialized) - print([x.message for x in record]) - assert len(record) == 12 - else: - loaded = Nested_with_Container.load(serialized) + loaded = _loading_test_wrapper(Nested_with_Container, serialized, assert_record_len=12) assert loaded == instance @@ -394,5 +390,5 @@ def test_nested_custom(): "__format__": "nested_custom(SerializableDataclass)", } assert serialized == expected_ser - loaded = nested_custom.load(serialized) + loaded = _loading_test_wrapper(nested_custom, serialized) assert loaded == instance From 400bc2e3154a7fc533c79697555654f6d79968ef Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 17 Jun 2024 23:54:56 -0700 Subject: [PATCH 021/158] format --- .../json_serialize/serializable_dataclass.py | 5 +++-- .../serializable_dataclass/test_sdc_defaults.py | 6 +++--- .../test_sdc_properties_nested.py | 3 +-- .../test_serializable_dataclass.py | 17 ++++++++++------- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index f68aa948..18203c67 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -9,7 +9,6 @@ import warnings from typing import Any, Callable, Optional, Type, TypeVar, Union - # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access @@ -571,7 +570,9 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: if field.name in ctor_kwargs: if field_type_hint is not None: # TODO: recursive type hint checking like pydantic? - assert isinstance(ctor_kwargs[field.name], field_type_hint) + assert isinstance( + ctor_kwargs[field.name], field_type_hint + ) else: raise ValueError( f"Cannot get type hints for {cls.__name__}, and so cannot validate. Python version is {sys.version_info = }. You can:\n" diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 1f878efc..031ac429 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -1,6 +1,7 @@ from __future__ import annotations + import sys -from typing import Dict, Any +from typing import Any, Dict import pytest @@ -16,7 +17,7 @@ BELOW_PY_3_9: bool = sys.version_info < (3, 9) -def _loading_test_wrapper(cls, data, assert_record_len: int|None = None) -> Any: +def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: """wrapper for testing the load function, which accounts for version differences""" if BELOW_PY_3_9: with pytest.warns(UserWarning) as record: @@ -30,7 +31,6 @@ def _loading_test_wrapper(cls, data, assert_record_len: int|None = None) -> Any: return loaded - @serializable_dataclass class Config(SerializableDataclass): name: str = serializable_field(default="default_name") diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index d652cd21..339f0544 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -13,7 +13,7 @@ BELOW_PY_3_9: bool = sys.version_info < (3, 9) -def _loading_test_wrapper(cls, data, assert_record_len: int|None = None) -> Any: +def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: """wrapper for testing the load function, which accounts for version differences""" if BELOW_PY_3_9: with pytest.warns(UserWarning) as record: @@ -27,7 +27,6 @@ def _loading_test_wrapper(cls, data, assert_record_len: int|None = None) -> Any: return loaded - @serializable_dataclass class Person(SerializableDataclass): first_name: str diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index acbdc807..c140e232 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import Any import sys +from typing import Any import pytest @@ -17,8 +17,7 @@ BELOW_PY_3_9: bool = sys.version_info < (3, 9) - -def _loading_test_wrapper(cls, data, assert_record_len: int|None = None) -> Any: +def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: """wrapper for testing the load function, which accounts for version differences""" if BELOW_PY_3_9: with pytest.warns(UserWarning) as record: @@ -231,7 +230,9 @@ def full_name(self) -> str: serialized_data = my_instance.serialize() print(serialized_data) - loaded_instance = _loading_test_wrapper(MyClass, serialized_data, assert_record_len=3) + loaded_instance = _loading_test_wrapper( + MyClass, serialized_data, assert_record_len=3 + ) print(loaded_instance) @@ -284,7 +285,7 @@ def full_name(self) -> str: assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}" loaded = _loading_test_wrapper(FullPerson, serialized, assert_record_len=4) - + assert loaded == person @@ -350,8 +351,10 @@ def test_nested_with_container(): assert serialized == expected_ser - loaded = _loading_test_wrapper(Nested_with_Container, serialized, assert_record_len=12) - + loaded = _loading_test_wrapper( + Nested_with_Container, serialized, assert_record_len=12 + ) + assert loaded == instance From f366a11d738dd37d5bb6f9d49f437c6b30bb8dcb Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:05:28 -0700 Subject: [PATCH 022/158] fixing some warnings --- muutils/sysinfo.py | 1 + poetry.lock | 32 +++++++++++- pyproject.toml | 1 + .../test_sdc_properties_nested.py | 3 +- tests/unit/test_mlutils.py | 50 ++++++++++++------- tests/unit/test_tensor_utils.py | 2 +- 6 files changed, 67 insertions(+), 22 deletions(-) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index 5a4dd8b2..736e74d3 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -38,6 +38,7 @@ def python() -> dict: ver_tup = sys.version_info return { "version": sys.version, + "version_info": ver_tup, "major": ver_tup[0], "minor": ver_tup[1], "micro": ver_tup[2], diff --git a/poetry.lock b/poetry.lock index 69b82465..6066a7d1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1524,6 +1524,21 @@ docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx- test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] type = ["mypy (>=1.8)"] +[[package]] +name = "plotly" +version = "5.22.0" +description = "An open-source, interactive data visualization library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "plotly-5.22.0-py3-none-any.whl", hash = "sha256:68fc1901f098daeb233cc3dd44ec9dc31fb3ca4f4e53189344199c43496ed006"}, + {file = "plotly-5.22.0.tar.gz", hash = "sha256:859fdadbd86b5770ae2466e542b761b247d1c6b49daed765b95bb8c7063e7469"}, +] + +[package.dependencies] +packaging = "*" +tenacity = ">=6.2.0" + [[package]] name = "pluggy" version = "1.5.0" @@ -1853,6 +1868,21 @@ files = [ {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, ] +[[package]] +name = "tenacity" +version = "8.4.1" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-8.4.1-py3-none-any.whl", hash = "sha256:28522e692eda3e1b8f5e99c51464efcc0b9fc86933da92415168bc1c4e2308fa"}, + {file = "tenacity-8.4.1.tar.gz", hash = "sha256:54b1412b878ddf7e1f1577cd49527bad8cdef32421bd599beac0c6c3f10582fd"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tomli" version = "2.0.1" @@ -2141,4 +2171,4 @@ notebook = ["ipython"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "dcc7222c78e1d3797c2b80513b86c27c44394206e1f98af1e5943c4b3f506cab" +content-hash = "215baa1a19852a9078a038f47b35d45b08f614d239bdca0570f5c9d6b494ecbf" diff --git a/pyproject.toml b/pyproject.toml index 681a3a46..1dc6b1e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ mypy = "^1.0.1" pytest-cov = "^4.1.0" coverage-badge = "^1.1.0" matplotlib = "^3.0.0" +plotly = "^5.0.0" [build-system] requires = ["poetry-core"] diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 339f0544..35d9abb8 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -1,12 +1,13 @@ from __future__ import annotations import sys +from typing import Any import pytest from muutils.json_serialize import SerializableDataclass, serializable_dataclass -SUPPORS_KW_ONLY: bool = sys.version_info[1] >= 9 +SUPPORS_KW_ONLY: bool = sys.version_info >= (3, 10) print(f"{SUPPORS_KW_ONLY = }") diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index 3b71847e..a6b41410 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -1,6 +1,8 @@ import sys from pathlib import Path +import pytest + from muutils.mlutils import get_checkpoint_paths_for_run, register_method @@ -22,32 +24,42 @@ def test_get_checkpoint_paths_for_run(): assert checkpoint_paths == [(123, checkpoint1_path), (456, checkpoint2_path)] +BELOW_PY_3_9: bool = sys.version_info < (3, 9) + def test_register_method(): - class TestEvalsA: - evals = {} - @register_method(evals) - @staticmethod - def eval_function(): - pass + with pytest.warns(UserWarning) as record: + + class TestEvalsA: + evals = {} - @staticmethod - def other_function(): - pass + @register_method(evals) + @staticmethod + def eval_function(): + pass - class TestEvalsB: - evals = {} + @staticmethod + def other_function(): + pass - @register_method(evals) - @staticmethod - def other_eval_function(): - pass + class TestEvalsB: + evals = {} + + @register_method(evals) + @staticmethod + def other_eval_function(): + pass + + if BELOW_PY_3_9: + assert len(record) == 2 + else: + assert len(record) == 0 evalsA = TestEvalsA.evals evalsB = TestEvalsB.evals - if sys.version_info >= (3, 9): - assert list(evalsA.keys()) == ["eval_function"] - assert list(evalsB.keys()) == ["other_eval_function"] - else: + if BELOW_PY_3_9: assert len(evalsA) == 1 assert len(evalsB) == 1 + else: + assert list(evalsA.keys()) == ["eval_function"] + assert list(evalsB.keys()) == ["other_eval_function"] diff --git a/tests/unit/test_tensor_utils.py b/tests/unit/test_tensor_utils.py index f61c4f1e..5efe16b6 100644 --- a/tests/unit/test_tensor_utils.py +++ b/tests/unit/test_tensor_utils.py @@ -24,7 +24,7 @@ def test_jaxtype_factory(): - ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) + ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore") assert ATensor.__name__ == "ATensor" assert "default_jax_dtype = " in ATensor.__doc__ From fa2839e645e779d1a29a1e59096e52ecc1168ebe Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:28:52 -0700 Subject: [PATCH 023/158] NO WARNINGS --- muutils/sysinfo.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index 736e74d3..4e30298d 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -1,9 +1,10 @@ +from __future__ import annotations import os import subprocess import sys import typing -from pip._internal.operations.freeze import freeze as pip_freeze +from importlib.metadata import distributions def _popen( @@ -49,7 +50,10 @@ def python() -> dict: @staticmethod def pip() -> dict: """installed packages info""" - pckgs: typing.List[str] = [x for x in pip_freeze(local_only=True)] + pckgs: list[tuple[str, str]] = [ + (x.name, x.version) + for x in distributions() + ] return { "n_packages": len(pckgs), "packages": pckgs, @@ -141,7 +145,7 @@ def platform() -> dict: return {x: getattr(platform, x)() for x in items} @staticmethod - def git_info() -> dict: + def git_info(with_log: bool = False) -> dict: git_version: dict = _popen(["git", "version"]) git_status: dict = _popen(["git", "status"]) if git_status["stderr"].startswith("fatal: not a git repository"): @@ -150,13 +154,17 @@ def git_info() -> dict: "git status": git_status, } else: - return { + + output: dict[str, str] = { "git version": git_version["stdout"], "git status": git_status, "git branch": _popen(["git", "branch"], split_out=True), "git remote -v": _popen(["git", "remote", "-v"], split_out=True), - "git log": _popen(["git", "log"]), } + if with_log: + output["git log"] = _popen(["git", "log"], split_out=False) + + return output @classmethod def get_all( @@ -183,3 +191,8 @@ def get_all( ] ) } + + +if __name__ == "__main__": + import pprint + pprint.pprint(SysInfo.get_all()) \ No newline at end of file From 3be50f3d9fc453ed83bc142504e126c694d05676 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:29:47 -0700 Subject: [PATCH 024/158] format --- muutils/sysinfo.py | 10 ++++------ tests/unit/test_mlutils.py | 1 + tests/unit/test_tensor_utils.py | 4 +++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index 4e30298d..b5736f1c 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -1,9 +1,9 @@ from __future__ import annotations + import os import subprocess import sys import typing - from importlib.metadata import distributions @@ -50,10 +50,7 @@ def python() -> dict: @staticmethod def pip() -> dict: """installed packages info""" - pckgs: list[tuple[str, str]] = [ - (x.name, x.version) - for x in distributions() - ] + pckgs: list[tuple[str, str]] = [(x.name, x.version) for x in distributions()] return { "n_packages": len(pckgs), "packages": pckgs, @@ -195,4 +192,5 @@ def get_all( if __name__ == "__main__": import pprint - pprint.pprint(SysInfo.get_all()) \ No newline at end of file + + pprint.pprint(SysInfo.get_all()) diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index a6b41410..ad3f3798 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -26,6 +26,7 @@ def test_get_checkpoint_paths_for_run(): BELOW_PY_3_9: bool = sys.version_info < (3, 9) + def test_register_method(): with pytest.warns(UserWarning) as record: diff --git a/tests/unit/test_tensor_utils.py b/tests/unit/test_tensor_utils.py index 5efe16b6..77aa8000 100644 --- a/tests/unit/test_tensor_utils.py +++ b/tests/unit/test_tensor_utils.py @@ -24,7 +24,9 @@ def test_jaxtype_factory(): - ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore") + ATensor = jaxtype_factory( + "ATensor", torch.Tensor, jaxtyping.Float, legacy_mode="ignore" + ) assert ATensor.__name__ == "ATensor" assert "default_jax_dtype = " in ATensor.__doc__ From feb9b0e66ab5295ed7d89c232c598b816bb4b91f Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:33:04 -0700 Subject: [PATCH 025/158] fix for not warning on python >=3.10 --- tests/unit/test_mlutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index ad3f3798..8deb2f52 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -29,7 +29,7 @@ def test_get_checkpoint_paths_for_run(): def test_register_method(): - with pytest.warns(UserWarning) as record: + with pytest.warns() as record: class TestEvalsA: evals = {} From 2d8cde8899c9181786a4bfc93813d476b8006bb4 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:34:20 -0700 Subject: [PATCH 026/158] fix kw_only args for py 3.9 --- muutils/json_serialize/serializable_dataclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 18203c67..986573a6 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -433,7 +433,7 @@ def wrap(cls: Type[T]) -> Type[T]: setattr(cls, field_name, field_value) # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy - if sys.version_info[1] < 9: + if sys.version_info < (3, 10): if "kw_only" in kwargs: if kwargs["kw_only"] == True: # noqa: E712 raise ValueError("kw_only is not supported in python >=3.9") From 6a81a6dc25c62b67ab9b807a59c259bef3199bc2 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:37:08 -0700 Subject: [PATCH 027/158] bump pytest dep --- poetry.lock | 16 ++++++++-------- pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/poetry.lock b/poetry.lock index 6066a7d1..691fc25e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1670,13 +1670,13 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pytest" -version = "7.4.4" +version = "8.2.2" description = "pytest: simple powerful testing with Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, - {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, + {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, + {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, ] [package.dependencies] @@ -1684,11 +1684,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=0.12,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} +pluggy = ">=1.5,<2.0" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-cov" @@ -2171,4 +2171,4 @@ notebook = ["ipython"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "215baa1a19852a9078a038f47b35d45b08f614d239bdca0570f5c9d6b494ecbf" +content-hash = "c2a126688ee5e43af857b80f2d7062ad7858ad4b95e114784e5ab595a780607b" diff --git a/pyproject.toml b/pyproject.toml index 1dc6b1e2..55a261de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ array = ["numpy", "torch", "jaxtyping"] notebook = ["ipython"] [tool.poetry.group.dev.dependencies] -pytest = "^7.2.1" +pytest = "^8.2.2" black = "^24.1.1" pylint = "^2.16.4" pycln = "^2.1.3" From 542a439d6e40e0fde56982260a63c9403b221767 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:42:21 -0700 Subject: [PATCH 028/158] wip --- .../serializable_dataclass/test_sdc_defaults.py | 4 ++-- .../serializable_dataclass/test_sdc_properties_nested.py | 4 ++-- .../serializable_dataclass/test_serializable_dataclass.py | 4 ++-- tests/unit/test_mlutils.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 031ac429..91c01b72 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -14,12 +14,12 @@ # pylint: disable=missing-class-docstring -BELOW_PY_3_9: bool = sys.version_info < (3, 9) +BELOW_PY_3_10: bool = sys.version_info < (3, 10) def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: """wrapper for testing the load function, which accounts for version differences""" - if BELOW_PY_3_9: + if BELOW_PY_3_10: with pytest.warns(UserWarning) as record: loaded = cls.load(data) print([x.message for x in record]) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 35d9abb8..ac7c20a4 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -11,12 +11,12 @@ print(f"{SUPPORS_KW_ONLY = }") -BELOW_PY_3_9: bool = sys.version_info < (3, 9) +BELOW_PY_3_10: bool = sys.version_info < (3, 10) def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: """wrapper for testing the load function, which accounts for version differences""" - if BELOW_PY_3_9: + if BELOW_PY_3_10: with pytest.warns(UserWarning) as record: loaded = cls.load(data) print([x.message for x in record]) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index c140e232..3c43a11c 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -14,12 +14,12 @@ # pylint: disable=missing-class-docstring, unused-variable -BELOW_PY_3_9: bool = sys.version_info < (3, 9) +BELOW_PY_3_10: bool = sys.version_info < (3, 10) def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: """wrapper for testing the load function, which accounts for version differences""" - if BELOW_PY_3_9: + if BELOW_PY_3_10: with pytest.warns(UserWarning) as record: loaded = cls.load(data) print([x.message for x in record]) diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index 8deb2f52..72d7298d 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -24,7 +24,7 @@ def test_get_checkpoint_paths_for_run(): assert checkpoint_paths == [(123, checkpoint1_path), (456, checkpoint2_path)] -BELOW_PY_3_9: bool = sys.version_info < (3, 9) +BELOW_PY_3_10: bool = sys.version_info < (3, 9) def test_register_method(): @@ -51,14 +51,14 @@ class TestEvalsB: def other_eval_function(): pass - if BELOW_PY_3_9: + if BELOW_PY_3_10: assert len(record) == 2 else: assert len(record) == 0 evalsA = TestEvalsA.evals evalsB = TestEvalsB.evals - if BELOW_PY_3_9: + if BELOW_PY_3_10: assert len(evalsA) == 1 assert len(evalsB) == 1 else: From a93e18811e460f318902913182ae41bf8e7862bf Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:45:30 -0700 Subject: [PATCH 029/158] wip --- tests/unit/test_mlutils.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index 72d7298d..8b09fe58 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -27,34 +27,32 @@ def test_get_checkpoint_paths_for_run(): BELOW_PY_3_10: bool = sys.version_info < (3, 9) -def test_register_method(): +def test_register_method(recwarn): - with pytest.warns() as record: + class TestEvalsA: + evals = {} - class TestEvalsA: - evals = {} + @register_method(evals) + @staticmethod + def eval_function(): + pass - @register_method(evals) - @staticmethod - def eval_function(): - pass + @staticmethod + def other_function(): + pass - @staticmethod - def other_function(): - pass + class TestEvalsB: + evals = {} - class TestEvalsB: - evals = {} - - @register_method(evals) - @staticmethod - def other_eval_function(): - pass + @register_method(evals) + @staticmethod + def other_eval_function(): + pass if BELOW_PY_3_10: - assert len(record) == 2 + assert len(recwarn) == 2 else: - assert len(record) == 0 + assert len(recwarn) == 0 evalsA = TestEvalsA.evals evalsB = TestEvalsB.evals From 681e691b1d0d304e2dc3c2b98084096a9f8a0769 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:46:21 -0700 Subject: [PATCH 030/158] format --- tests/unit/test_mlutils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index 8b09fe58..c094e9d9 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -1,8 +1,6 @@ import sys from pathlib import Path -import pytest - from muutils.mlutils import get_checkpoint_paths_for_run, register_method From 5706978f8e7a0499743d453182fe5b6d0701c9e5 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:49:27 -0700 Subject: [PATCH 031/158] wip --- muutils/json_serialize/serializable_dataclass.py | 2 +- .../serializable_dataclass/test_sdc_properties_nested.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 986573a6..f2d98216 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -75,7 +75,7 @@ def __init__( super_kwargs["metadata"] = types.MappingProxyType({}) # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy - if sys.version_info[1] < 9: + if sys.version_info < (3, 10): if super_kwargs["kw_only"] == True: # noqa: E712 raise ValueError("kw_only is not supported in python >=3.9") else: diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index ac7c20a4..a1440bf6 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -7,9 +7,9 @@ from muutils.json_serialize import SerializableDataclass, serializable_dataclass -SUPPORS_KW_ONLY: bool = sys.version_info >= (3, 10) +SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10) -print(f"{SUPPORS_KW_ONLY = }") +print(f"{SUPPORTS_KW_ONLY = }") BELOW_PY_3_10: bool = sys.version_info < (3, 10) @@ -39,7 +39,7 @@ def full_name(self) -> str: @serializable_dataclass( - kw_only=SUPPORS_KW_ONLY, properties_to_serialize=["full_name", "full_title"] + kw_only=SUPPORTS_KW_ONLY, properties_to_serialize=["full_name", "full_title"] ) class TitledPerson(Person): title: str @@ -68,7 +68,7 @@ def test_serialize_person(): def test_serialize_titled_person(): instance = TitledPerson(first_name="Jane", last_name="Smith", title="Dr.") - if SUPPORS_KW_ONLY: + if SUPPORTS_KW_ONLY: with pytest.raises(TypeError): TitledPerson("Jane", "Smith", "Dr.") From f7afc5ee12bd18fe541a81524f942406af2bcd91 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:50:35 -0700 Subject: [PATCH 032/158] no linting for py <3.10 --- .github/workflows/checks.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 37ca9bc9..5fae782e 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -65,4 +65,5 @@ jobs: run: make test - name: lint + if: "${{ matrix.versions.python >= '3.10' }}" run: make lint From a14a781d93fd22cf51e48af935ac7e7167095cca Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:53:44 -0700 Subject: [PATCH 033/158] asserting record len would need to be per-version --- .../test_sdc_defaults.py | 3 +-- .../test_serializable_dataclass.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 91c01b72..9bec3c55 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -23,8 +23,7 @@ def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> An with pytest.warns(UserWarning) as record: loaded = cls.load(data) print([x.message for x in record]) - if assert_record_len is not None: - assert len(record) == assert_record_len + return loaded else: loaded = cls.load(data) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 3c43a11c..7f329199 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -126,7 +126,7 @@ def test_simple_fields_serialization(simple_fields_instance): def test_simple_fields_loading(simple_fields_instance): serialized = simple_fields_instance.serialize() - loaded = _loading_test_wrapper(SimpleFields, serialized, assert_record_len=4) + loaded = _loading_test_wrapper(SimpleFields, serialized) #, assert_record_len=4) assert loaded == simple_fields_instance assert loaded.diff(simple_fields_instance) == {} @@ -145,7 +145,7 @@ def test_field_options_serialization(field_options_instance): def test_field_options_loading(field_options_instance): serialized = field_options_instance.serialize() - loaded = _loading_test_wrapper(FieldOptions, serialized, assert_record_len=3) + loaded = _loading_test_wrapper(FieldOptions, serialized) #, assert_record_len=3) assert loaded == field_options_instance @@ -161,7 +161,7 @@ def test_with_property_serialization(with_property_instance): def test_with_property_loading(with_property_instance): serialized = with_property_instance.serialize() - loaded = _loading_test_wrapper(WithProperty, serialized, assert_record_len=2) + loaded = _loading_test_wrapper(WithProperty, serialized) #, assert_record_len=2) assert loaded == with_property_instance @@ -207,7 +207,7 @@ def test_nested_serialization(person_instance): def test_nested_loading(person_instance): serialized = person_instance.serialize() - loaded = _loading_test_wrapper(Person, serialized, assert_record_len=6) + loaded = _loading_test_wrapper(Person, serialized) #, assert_record_len=6) assert loaded == person_instance assert loaded.address == person_instance.address @@ -231,7 +231,7 @@ def full_name(self) -> str: print(serialized_data) loaded_instance = _loading_test_wrapper( - MyClass, serialized_data, assert_record_len=3 + MyClass, serialized_data) #, assert_record_len=3 ) print(loaded_instance) @@ -250,7 +250,7 @@ class SimpleClass(SerializableDataclass): "__format__": "SimpleClass(SerializableDataclass)", } - loaded = _loading_test_wrapper(SimpleClass, serialized, assert_record_len=2) + loaded = _loading_test_wrapper(SimpleClass, serialized) #, assert_record_len=2) assert loaded == simple @@ -284,7 +284,7 @@ def full_name(self) -> str: } assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}" - loaded = _loading_test_wrapper(FullPerson, serialized, assert_record_len=4) + loaded = _loading_test_wrapper(FullPerson, serialized) #, assert_record_len=4) assert loaded == person @@ -303,7 +303,7 @@ class CustomSerialization(SerializableDataclass): "__format__": "CustomSerialization(SerializableDataclass)", } - loaded = _loading_test_wrapper(CustomSerialization, serialized, assert_record_len=1) + loaded = _loading_test_wrapper(CustomSerialization, serialized) #, assert_record_len=1) assert loaded == custom @@ -352,7 +352,7 @@ def test_nested_with_container(): assert serialized == expected_ser loaded = _loading_test_wrapper( - Nested_with_Container, serialized, assert_record_len=12 + Nested_with_Container, serialized #, assert_record_len=12 ) assert loaded == instance From c7701df68837c9dd3de7b41e0a58d17229df2c0b Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:54:31 -0700 Subject: [PATCH 034/158] format, syntax error --- .../test_sdc_defaults.py | 2 +- .../test_serializable_dataclass.py | 20 ++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 9bec3c55..b6d380d1 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -23,7 +23,7 @@ def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> An with pytest.warns(UserWarning) as record: loaded = cls.load(data) print([x.message for x in record]) - + return loaded else: loaded = cls.load(data) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 7f329199..124fe08e 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -126,7 +126,7 @@ def test_simple_fields_serialization(simple_fields_instance): def test_simple_fields_loading(simple_fields_instance): serialized = simple_fields_instance.serialize() - loaded = _loading_test_wrapper(SimpleFields, serialized) #, assert_record_len=4) + loaded = _loading_test_wrapper(SimpleFields, serialized) # , assert_record_len=4) assert loaded == simple_fields_instance assert loaded.diff(simple_fields_instance) == {} @@ -145,7 +145,7 @@ def test_field_options_serialization(field_options_instance): def test_field_options_loading(field_options_instance): serialized = field_options_instance.serialize() - loaded = _loading_test_wrapper(FieldOptions, serialized) #, assert_record_len=3) + loaded = _loading_test_wrapper(FieldOptions, serialized) # , assert_record_len=3) assert loaded == field_options_instance @@ -161,7 +161,7 @@ def test_with_property_serialization(with_property_instance): def test_with_property_loading(with_property_instance): serialized = with_property_instance.serialize() - loaded = _loading_test_wrapper(WithProperty, serialized) #, assert_record_len=2) + loaded = _loading_test_wrapper(WithProperty, serialized) # , assert_record_len=2) assert loaded == with_property_instance @@ -207,7 +207,7 @@ def test_nested_serialization(person_instance): def test_nested_loading(person_instance): serialized = person_instance.serialize() - loaded = _loading_test_wrapper(Person, serialized) #, assert_record_len=6) + loaded = _loading_test_wrapper(Person, serialized) # , assert_record_len=6) assert loaded == person_instance assert loaded.address == person_instance.address @@ -231,7 +231,7 @@ def full_name(self) -> str: print(serialized_data) loaded_instance = _loading_test_wrapper( - MyClass, serialized_data) #, assert_record_len=3 + MyClass, serialized_data # , assert_record_len=3 ) print(loaded_instance) @@ -250,7 +250,7 @@ class SimpleClass(SerializableDataclass): "__format__": "SimpleClass(SerializableDataclass)", } - loaded = _loading_test_wrapper(SimpleClass, serialized) #, assert_record_len=2) + loaded = _loading_test_wrapper(SimpleClass, serialized) # , assert_record_len=2) assert loaded == simple @@ -284,7 +284,7 @@ def full_name(self) -> str: } assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}" - loaded = _loading_test_wrapper(FullPerson, serialized) #, assert_record_len=4) + loaded = _loading_test_wrapper(FullPerson, serialized) # , assert_record_len=4) assert loaded == person @@ -303,7 +303,9 @@ class CustomSerialization(SerializableDataclass): "__format__": "CustomSerialization(SerializableDataclass)", } - loaded = _loading_test_wrapper(CustomSerialization, serialized) #, assert_record_len=1) + loaded = _loading_test_wrapper( + CustomSerialization, serialized + ) # , assert_record_len=1) assert loaded == custom @@ -352,7 +354,7 @@ def test_nested_with_container(): assert serialized == expected_ser loaded = _loading_test_wrapper( - Nested_with_Container, serialized #, assert_record_len=12 + Nested_with_Container, serialized # , assert_record_len=12 ) assert loaded == instance From 70665a9c722a8ad6f3fabc6f5f7b9819b5ef3086 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 00:55:47 -0700 Subject: [PATCH 035/158] ugh --- .../json_serialize/serializable_dataclass/test_sdc_defaults.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index b6d380d1..91c01b72 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -23,7 +23,8 @@ def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> An with pytest.warns(UserWarning) as record: loaded = cls.load(data) print([x.message for x in record]) - + if assert_record_len is not None: + assert len(record) == assert_record_len return loaded else: loaded = cls.load(data) From aa11e5b405687951ac2dbe68d774b93c2518ea5e Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:00:39 -0700 Subject: [PATCH 036/158] wip --- muutils/json_serialize/serializable_dataclass.py | 2 +- .../serializable_dataclass/test_serializable_dataclass.py | 2 +- tests/unit/test_mlutils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index f2d98216..4dd4da30 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -582,7 +582,7 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: + " - coming in a future release, specify custom type validation functions\n" ) else: - # TODO: raise an exception here? Can't validate if no type hint given + # TODO: raise an exception here? Can't validate if data given warnings.warn( f"Field '{field.name}' on class {cls} has no type hint, but {field.assert_type = }\n{field = }\n{cls_type_hints = }\n{data = }" ) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 124fe08e..31ecd1bc 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -384,7 +384,7 @@ class nested_custom(SerializableDataclass): data1: Custom_class_with_serialization -def test_nested_custom(): +def test_nested_custom(recwarn): # this will send some warnings but whatever instance = nested_custom( value=42.0, data1=Custom_class_with_serialization(1, "hello") ) diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index c094e9d9..fd91c761 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -22,7 +22,7 @@ def test_get_checkpoint_paths_for_run(): assert checkpoint_paths == [(123, checkpoint1_path), (456, checkpoint2_path)] -BELOW_PY_3_10: bool = sys.version_info < (3, 9) +BELOW_PY_3_10: bool = sys.version_info < (3, 10) def test_register_method(recwarn): From 57d5dd62788b3bf66c5de5c7b73b502ebf121ce5 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:03:46 -0700 Subject: [PATCH 037/158] fix linting --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 5fae782e..c8337de1 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -65,5 +65,5 @@ jobs: run: make test - name: lint - if: "${{ matrix.versions.python >= '3.10' }}" + if: matrix.versions.python == '3.8' || matrix.versions.python == '3.9' run: make lint From 254c0f11cf2f890f4654a443ccb8968643430b0f Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:07:24 -0700 Subject: [PATCH 038/158] fix CI again --- .github/workflows/checks.yml | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index c8337de1..4fdd5d14 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -30,18 +30,18 @@ jobs: strategy: matrix: versions: - - python: "3.8" - torch: "1.13.1" - - python: "3.9" - torch: "1.13.1" - - python: "3.10" - torch: "1.13.1" - - python: "3.10" - torch: "2.3.1" - - python: "3.11" - torch: "2.3.1" - # - python: "3.12" - # torch: "2.3.1" + - python: '3.8' + torch: '1.13.1' + - python: '3.9' + torch: '1.13.1' + - python: '3.10' + torch: '1.13.1' + - python: '3.10' + torch: '2.3.1' + - python: '3.11' + torch: '2.3.1' + # - python: '3.12' + # torch: '2.3.1' steps: - name: Checkout code uses: actions/checkout@v2 @@ -65,5 +65,5 @@ jobs: run: make test - name: lint - if: matrix.versions.python == '3.8' || matrix.versions.python == '3.9' + if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} run: make lint From e4095c51b5c7439c4e438e35dd375d32b3a0832c Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:11:09 -0700 Subject: [PATCH 039/158] wip --- muutils/mlutils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/muutils/mlutils.py b/muutils/mlutils.py index 57fdc800..70e29f7d 100644 --- a/muutils/mlutils.py +++ b/muutils/mlutils.py @@ -137,16 +137,19 @@ def register_method( """Decorator to add a method to the method_dict""" def decorator(method: F) -> F: + method_name_orig: str | None method_name: str if custom_name is None: - method_name: str | None = getattr(method, "__name__", None) - if method_name is None: + method_name_orig: str = getattr(method, "__name__", None) + if method_name_orig is None: warnings.warn( f"Method {method} does not have a name, using sanitized repr" ) from muutils.misc import sanitize_identifier method_name = sanitize_identifier(repr(method)) + else: + method_name = method_name_orig else: method_name = custom_name method.__name__ = custom_name From b9a30e2af8b7307818a64fcd35fff3fe47d720fe Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:11:55 -0700 Subject: [PATCH 040/158] fix sysinfo type hint --- muutils/sysinfo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index b5736f1c..f846833a 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -152,7 +152,7 @@ def git_info(with_log: bool = False) -> dict: } else: - output: dict[str, str] = { + output: dict = { "git version": git_version["stdout"], "git status": git_status, "git branch": _popen(["git", "branch"], split_out=True), From bc672788a1b4eb438c97ce6db9a7c8b55fd1f406 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:12:57 -0700 Subject: [PATCH 041/158] fix statcounter type hint --- muutils/statcounter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muutils/statcounter.py b/muutils/statcounter.py index bce0b105..99edc86b 100644 --- a/muutils/statcounter.py +++ b/muutils/statcounter.py @@ -25,7 +25,7 @@ def universal_flatten( if hasattr(arr, "flatten") and callable(arr.flatten): # type: ignore return arr.flatten() # type: ignore elif isinstance(arr, Sequence): - elements_iterable: List[bool] = [isinstance(x, Sequence) for x in arr] + elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr] if require_rectangular and (all(elements_iterable) != any(elements_iterable)): raise ValueError("arr contains mixed iterable and non-iterable elements") if any(elements_iterable): From 0591e9da93bcc5ed9341ed09691f180fb0381b59 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:14:04 -0700 Subject: [PATCH 042/158] more type fixes --- muutils/mlutils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/muutils/mlutils.py b/muutils/mlutils.py index 70e29f7d..ced9494e 100644 --- a/muutils/mlutils.py +++ b/muutils/mlutils.py @@ -137,10 +137,9 @@ def register_method( """Decorator to add a method to the method_dict""" def decorator(method: F) -> F: - method_name_orig: str | None method_name: str if custom_name is None: - method_name_orig: str = getattr(method, "__name__", None) + method_name_orig: str | None = getattr(method, "__name__", None) if method_name_orig is None: warnings.warn( f"Method {method} does not have a name, using sanitized repr" From 4634db373a59b72ceb5a4d0d16c8837d1e4d58a4 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:17:01 -0700 Subject: [PATCH 043/158] newer type hints should work with future annotations --- muutils/dictmagic.py | 69 +++++++++---------- muutils/json_serialize/json_serialize.py | 4 +- muutils/json_serialize/util.py | 6 +- muutils/sysinfo.py | 6 +- muutils/tensor_utils.py | 18 +++-- .../test_sdc_defaults.py | 4 +- 6 files changed, 51 insertions(+), 56 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index c17ec0fc..3166eb50 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -6,7 +6,6 @@ from typing import ( Any, Callable, - Dict, Generic, Hashable, Iterable, @@ -21,7 +20,7 @@ _VT = TypeVar("_VT") -class DefaulterDict(Dict[_KT, _VT], Generic[_KT, _VT]): +class DefaulterDict(dict[_KT, _VT], Generic[_KT, _VT]): """like a defaultdict, but default_factory is passed the key as an argument""" def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs): @@ -55,7 +54,7 @@ def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict } -def dotlist_to_nested_dict(dot_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: +def dotlist_to_nested_dict(dot_dict: dict[str, Any], sep: str = ".") -> dict[str, Any]: """Convert a dict with dot-separated keys to a nested dict Example: @@ -76,11 +75,11 @@ def dotlist_to_nested_dict(dot_dict: Dict[str, Any], sep: str = ".") -> Dict[str def nested_dict_to_dotlist( - nested_dict: Dict[str, Any], + nested_dict: dict[str, Any], sep: str = ".", allow_lists: bool = False, -) -> Dict[str, Any]: - def _recurse(current: Any, parent_key: str = "") -> Dict[str, Any]: +) -> dict[str, Any]: + def _recurse(current: Any, parent_key: str = "") -> dict[str, Any]: items: dict = dict() new_key: str @@ -109,9 +108,9 @@ def _recurse(current: Any, parent_key: str = "") -> Dict[str, Any]: def update_with_nested_dict( - original: Dict[str, Any], - update: Dict[str, Any], -) -> Dict[str, Any]: + original: dict[str, Any], + update: dict[str, Any], +) -> dict[str, Any]: """Update a dict with a nested dict Example: @@ -119,9 +118,9 @@ def update_with_nested_dict( {'a': {'b': 2}, 'c': -1} # Arguments - - `original: Dict[str, Any]` + - `original: dict[str, Any]` the dict to update (will be modified in-place) - - `update: Dict[str, Any]` + - `update: dict[str, Any]` the dict to update with # Returns @@ -141,12 +140,12 @@ def update_with_nested_dict( def kwargs_to_nested_dict( - kwargs_dict: Dict[str, Any], + kwargs_dict: dict[str, Any], sep: str = ".", strip_prefix: Optional[str] = None, when_unknown_prefix: typing.Literal["raise", "warn", "ignore"] = "warn", transform_key: Optional[Callable[[str], str]] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """given kwargs from fire, convert them to a nested dict if strip_prefix is not None, then all keys must start with the prefix. by default, @@ -166,7 +165,7 @@ def main(**kwargs): ``` # Arguments - - `kwargs_dict: Dict[str, Any]` + - `kwargs_dict: dict[str, Any]` the kwargs dict to convert - `sep: str = "."` the separator to use for nested keys @@ -177,7 +176,7 @@ def main(**kwargs): - `transform_key: Callable[[str], str] | None = None` a function to apply to each key before adding it to the dict (applied after stripping the prefix) """ - filtered_kwargs: Dict[str, Any] = dict() + filtered_kwargs: dict[str, Any] = dict() for key, value in kwargs_dict.items(): if strip_prefix is not None: if not key.startswith(strip_prefix): @@ -212,8 +211,8 @@ def is_numeric_consecutive(lst: list[str]) -> bool: def condense_nested_dicts_numeric_keys( - data: Dict[str, Any], -) -> Dict[str, Any]: + data: dict[str, Any], +) -> dict[str, Any]: """condense a nested dict, by condensing numeric keys with matching values to ranges # Examples: @@ -239,7 +238,7 @@ def condense_nested_dicts_numeric_keys( return data # output dict - condensed_data: Dict[str, Any] = {} + condensed_data: dict[str, Any] = {} # Identify ranges of identical values and condense i: int = 0 @@ -259,15 +258,15 @@ def condense_nested_dicts_numeric_keys( def condense_nested_dicts_matching_values( - data: Dict[str, Any], + data: dict[str, Any], val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """condense a nested dict, by condensing keys with matching values # Examples: # Parameters: - - `data : Dict[str, Any]` + - `data : dict[str, Any]` data to process - `val_condense_fallback_mapping : Callable[[Any], Hashable] | None` a function to apply to each value before adding it to the dict (if it's not hashable) @@ -287,7 +286,7 @@ def condense_nested_dicts_matching_values( # Find all identical values and condense by stitching together keys values_grouped: defaultdict[Any, list[str]] = defaultdict(list) - data_persist: Dict[str, Any] = dict() + data_persist: dict[str, Any] = dict() for key, value in data.items(): if not isinstance(value, dict): try: @@ -313,11 +312,11 @@ def condense_nested_dicts_matching_values( def condense_nested_dicts( - data: Dict[str, Any], + data: dict[str, Any], condense_numeric_keys: bool = True, condense_matching_values: bool = True, val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """condense a nested dict, by condensing numeric or matching keys with matching values to ranges combines the functionality of `condense_nested_dicts_numeric_keys()` and `condense_nested_dicts_matching_values()` @@ -326,7 +325,7 @@ def condense_nested_dicts( it's not reversible because types are lost to make the printing pretty # Parameters: - - `data : Dict[str, Any]` + - `data : dict[str, Any]` data to process - `condense_numeric_keys : bool` whether to condense numeric keys (e.g. "1", "2", "3") to ranges (e.g. "[1-3]") @@ -351,7 +350,7 @@ def condense_nested_dicts( def tuple_dims_replace( - t: Tuple[int, ...], dims_names_map: Optional[Dict[int, str]] = None + t: Tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None ) -> Tuple[Union[int, str], ...]: if dims_names_map is None: return t @@ -359,7 +358,7 @@ def tuple_dims_replace( return tuple(dims_names_map.get(x, x) for x in t) -TensorDict = Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] +TensorDict = dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] TensorIterable = Iterable[Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] TensorDictFormats = Literal["dict", "json", "yaml", "yml"] @@ -375,19 +374,19 @@ def condense_tensor_dict( shapes_convert: Callable[[tuple], Any] = _default_shapes_convert, drop_batch_dims: int = 0, sep: str = ".", - dims_names_map: Optional[Dict[int, str]] = None, + dims_names_map: Optional[dict[int, str]] = None, condense_numeric_keys: bool = True, condense_matching_values: bool = True, val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, return_format: Optional[TensorDictFormats] = None, -) -> Union[str, Dict[str, str | Tuple[int, ...]]]: +) -> Union[str, dict[str, str | Tuple[int, ...]]]: """Convert a dictionary of tensors to a dictionary of shapes. by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`. # Parameters: - - `data : Dict[str, "torch.Tensor|np.ndarray"] | Iterable[Tuple[str, "torch.Tensor|np.ndarray"]]` + - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[Tuple[str, "torch.Tensor|np.ndarray"]]` a either a `TensorDict` dict from strings to tensors, or an `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` ) - `fmt : TensorDictFormats` format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. @@ -401,7 +400,7 @@ def condense_tensor_dict( - `sep : str` separator to use for nested keys (defaults to `'.'`) - - `dims_names_map : Dict[int, str] | None` + - `dims_names_map : dict[int, str] | None` convert certain dimension values in shape. not perfect, can be buggy (defaults to `None`) - `condense_numeric_keys : bool` @@ -417,7 +416,7 @@ def condense_tensor_dict( legacy alias for `fmt` kwarg # Returns: - - `str|Dict[str, str|Tuple[int, ...]]` + - `str|dict[str, str|Tuple[int, ...]]` dict if `return_format='dict'`, a string for `json` or `yaml` output # Examples: @@ -473,7 +472,7 @@ def condense_tensor_dict( ) # get shapes - data_shapes: Dict[str, Union[str, Tuple[int, ...]]] = { + data_shapes: dict[str, Union[str, Tuple[int, ...]]] = { k: shapes_convert( tuple_dims_replace( tuple(v.shape)[drop_batch_dims:], @@ -484,10 +483,10 @@ def condense_tensor_dict( } # nest the dict - data_nested: Dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) + data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) # condense the nested dict - data_condensed: Dict[str, Union[str, Tuple[int, ...]]] = condense_nested_dicts( + data_condensed: dict[str, Union[str, Tuple[int, ...]]] = condense_nested_dicts( data=data_nested, condense_numeric_keys=condense_numeric_keys, condense_matching_values=condense_matching_values, diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index 42d1f292..6f290436 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -4,7 +4,7 @@ import warnings from dataclasses import dataclass, is_dataclass from pathlib import Path -from typing import Any, Callable, Dict, Iterable, Mapping, Set, Union +from typing import Any, Callable, Iterable, Mapping, Set, Union try: from muutils.json_serialize.array import ArrayMode, serialize_array @@ -40,7 +40,7 @@ "__annotations__", ) -SERIALIZER_SPECIAL_FUNCS: Dict[str, Callable] = { +SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = { "str": str, "dir": dir, "type": try_catch(lambda x: str(type(x).__name__)), diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 287675a2..63d44479 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -5,7 +5,7 @@ import sys import typing import warnings -from typing import Any, Callable, Dict, Iterable, Literal, Union +from typing import Any, Callable, Iterable, Literal, Union _NUMPY_WORKING: bool try: @@ -18,8 +18,8 @@ TypeErrorMode = Union[ErrorMode, Literal["try_convert"]] -JSONitem = Union[bool, int, float, str, list, Dict[str, Any], None] -JSONdict = Dict[str, JSONitem] +JSONitem = Union[bool, int, float, str, list, dict[str, Any], None] +JSONdict = dict[str, JSONitem] Hashableitem = Union[bool, int, float, str, tuple] # or if python version <3.9 diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index f846833a..bd36b690 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -7,14 +7,12 @@ from importlib.metadata import distributions -def _popen( - cmd: typing.List[str], split_out: bool = False -) -> typing.Dict[str, typing.Any]: +def _popen(cmd: list[str], split_out: bool = False) -> dict[str, typing.Any]: p: subprocess.Popen = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - p_out: typing.Union[str, typing.List[str], None] + p_out: typing.Union[str, list[str], None] if p.stdout is not None: p_out = p.stdout.read().decode("utf-8") if split_out: diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index 6be949dc..ce3b38df 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -124,7 +124,7 @@ def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: f"legacy type annotation was used:\n{cls.param_info(params)}" ) # MyTensor[("dim1", "dim2"), int] - shape_anot: typing.List[str] = list() + shape_anot: list[str] = list() for x in params[0]: if isinstance(x, str): shape_anot.append(x) @@ -261,7 +261,7 @@ def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dt TORCH_DTYPE_MAP["bool"] = torch.bool -TORCH_OPTIMIZERS_MAP: typing.Dict[str, typing.Type[torch.optim.Optimizer]] = { +TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = { "Adagrad": torch.optim.Adagrad, "Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW, @@ -287,7 +287,7 @@ def pad_tensor( set `rpad = True` to pad on the right instead""" - temp: typing.List[torch.Tensor] = [ + temp: list[torch.Tensor] = [ torch.full( (padded_length - tensor.shape[0],), pad_value, @@ -326,7 +326,7 @@ def pad_array( set `rpad = True` to pad on the right instead""" - temp: typing.List[np.ndarray] = [ + temp: list[np.ndarray] = [ np.full( (padded_length - array.shape[0],), pad_value, @@ -355,14 +355,12 @@ def rpad_array( return pad_array(array, pad_length, pad_value, rpad=True) -def get_dict_shapes( - d: typing.Dict[str, "torch.Tensor"] -) -> typing.Dict[str, typing.Tuple[int, ...]]: +def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, typing.Tuple[int, ...]]: """given a state dict or cache dict, compute the shapes and put them in a nested dict""" return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) -def string_dict_shapes(d: typing.Dict[str, "torch.Tensor"]) -> str: +def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str: """printable version of get_dict_shapes""" return json.dumps( dotlist_to_nested_dict( @@ -430,8 +428,8 @@ def compare_state_dicts( ) # check tensors match - shape_failed: typing.List[str] = list() - vals_failed: typing.List[str] = list() + shape_failed: list[str] = list() + vals_failed: list[str] = list() for k, v1 in d1.items(): v2 = d2[k] # check shapes first diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 91c01b72..7ea89c7a 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import Any, Dict +from typing import Any import pytest @@ -65,7 +65,7 @@ def test_sdc_strip_format_jser(): assert recovered == instance -TYPE_MAP: Dict[str, type] = {x.__name__: x for x in [int, float, str, bool]} +TYPE_MAP: dict[str, type] = {x.__name__: x for x in [int, float, str, bool]} @serializable_dataclass From 53cc8f4678806f2975721c381c515da02083de7d Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:19:17 -0700 Subject: [PATCH 044/158] Tuple[] -> tuple[] --- muutils/dictmagic.py | 19 +++++++++---------- muutils/sysinfo.py | 6 +++--- muutils/tensor_utils.py | 2 +- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 3166eb50..37e5ecfe 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -11,7 +11,6 @@ Iterable, Literal, Optional, - Tuple, TypeVar, Union, ) @@ -350,8 +349,8 @@ def condense_nested_dicts( def tuple_dims_replace( - t: Tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None -) -> Tuple[Union[int, str], ...]: + t: tuple[int, ...], dims_names_map: Optional[dict[int, str]] = None +) -> tuple[Union[int, str], ...]: if dims_names_map is None: return t else: @@ -359,7 +358,7 @@ def tuple_dims_replace( TensorDict = dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] -TensorIterable = Iterable[Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] +TensorIterable = Iterable[tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] TensorDictFormats = Literal["dict", "json", "yaml", "yml"] @@ -379,14 +378,14 @@ def condense_tensor_dict( condense_matching_values: bool = True, val_condense_fallback_mapping: Optional[Callable[[Any], Hashable]] = None, return_format: Optional[TensorDictFormats] = None, -) -> Union[str, dict[str, str | Tuple[int, ...]]]: +) -> Union[str, dict[str, str | tuple[int, ...]]]: """Convert a dictionary of tensors to a dictionary of shapes. by default, values are converted to strings of their shapes (for nice printing). If you want the actual shapes, set `shapes_convert = lambda x: x` or `shapes_convert = None`. # Parameters: - - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[Tuple[str, "torch.Tensor|np.ndarray"]]` + - `data : dict[str, "torch.Tensor|np.ndarray"] | Iterable[tuple[str, "torch.Tensor|np.ndarray"]]` a either a `TensorDict` dict from strings to tensors, or an `TensorIterable` iterable of (key, tensor) pairs (like you might get from a `dict().items())` ) - `fmt : TensorDictFormats` format to return the result in -- either a dict, or dump to json/yaml directly for pretty printing. will crash if yaml is not installed. @@ -416,7 +415,7 @@ def condense_tensor_dict( legacy alias for `fmt` kwarg # Returns: - - `str|dict[str, str|Tuple[int, ...]]` + - `str|dict[str, str|tuple[int, ...]]` dict if `return_format='dict'`, a string for `json` or `yaml` output # Examples: @@ -467,12 +466,12 @@ def condense_tensor_dict( shapes_convert = lambda x: x # convert to iterable - data_items: "Iterable[Tuple[str, Union[torch.Tensor,np.ndarray]]]" = ( # type: ignore + data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = ( # type: ignore data.items() if hasattr(data, "items") and callable(data.items) else data # type: ignore ) # get shapes - data_shapes: dict[str, Union[str, Tuple[int, ...]]] = { + data_shapes: dict[str, Union[str, tuple[int, ...]]] = { k: shapes_convert( tuple_dims_replace( tuple(v.shape)[drop_batch_dims:], @@ -486,7 +485,7 @@ def condense_tensor_dict( data_nested: dict[str, Any] = dotlist_to_nested_dict(data_shapes, sep=sep) # condense the nested dict - data_condensed: dict[str, Union[str, Tuple[int, ...]]] = condense_nested_dicts( + data_condensed: dict[str, Union[str, tuple[int, ...]]] = condense_nested_dicts( data=data_nested, condense_numeric_keys=condense_numeric_keys, condense_matching_values=condense_matching_values, diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index bd36b690..f6835461 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -164,10 +164,10 @@ def git_info(with_log: bool = False) -> dict: @classmethod def get_all( cls, - include: typing.Optional[typing.Tuple[str, ...]] = None, - exclude: typing.Tuple[str, ...] = tuple(), + include: typing.Optional[tuple[str, ...]] = None, + exclude: tuple[str, ...] = tuple(), ) -> dict: - include_meta: typing.Tuple[str, ...] + include_meta: tuple[str, ...] if include is None: include_meta = tuple(cls.__dict__.keys()) else: diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index ce3b38df..c84f15a7 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -355,7 +355,7 @@ def rpad_array( return pad_array(array, pad_length, pad_value, rpad=True) -def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, typing.Tuple[int, ...]]: +def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]: """given a state dict or cache dict, compute the shapes and put them in a nested dict""" return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) From 42c4eee51bf0c2befcff77e4dc0e5962027a42e5 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:29:08 -0700 Subject: [PATCH 045/158] fix types, for real this time! --- muutils/dictmagic.py | 14 ++++++++------ muutils/json_serialize/util.py | 4 ++-- muutils/tensor_utils.py | 2 ++ tests/unit/test_dictmagic.py | 2 ++ 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 37e5ecfe..28ea0d02 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -19,7 +19,7 @@ _VT = TypeVar("_VT") -class DefaulterDict(dict[_KT, _VT], Generic[_KT, _VT]): +class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): """like a defaultdict, but default_factory is passed the key as an argument""" def __init__(self, default_factory: Callable[[_KT], _VT], *args, **kwargs): @@ -53,7 +53,9 @@ def defaultdict_to_dict_recursive(dd: Union[defaultdict, DefaulterDict]) -> dict } -def dotlist_to_nested_dict(dot_dict: dict[str, Any], sep: str = ".") -> dict[str, Any]: +def dotlist_to_nested_dict( + dot_dict: typing.Dict[str, Any], sep: str = "." +) -> typing.Dict[str, Any]: """Convert a dict with dot-separated keys to a nested dict Example: @@ -74,11 +76,11 @@ def dotlist_to_nested_dict(dot_dict: dict[str, Any], sep: str = ".") -> dict[str def nested_dict_to_dotlist( - nested_dict: dict[str, Any], + nested_dict: typing.Dict[str, Any], sep: str = ".", allow_lists: bool = False, ) -> dict[str, Any]: - def _recurse(current: Any, parent_key: str = "") -> dict[str, Any]: + def _recurse(current: Any, parent_key: str = "") -> typing.Dict[str, Any]: items: dict = dict() new_key: str @@ -357,8 +359,8 @@ def tuple_dims_replace( return tuple(dims_names_map.get(x, x) for x in t) -TensorDict = dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] -TensorIterable = Iterable[tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] +TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] +TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] TensorDictFormats = Literal["dict", "json", "yaml", "yml"] diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 63d44479..87d80d03 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -18,8 +18,8 @@ TypeErrorMode = Union[ErrorMode, Literal["try_convert"]] -JSONitem = Union[bool, int, float, str, list, dict[str, Any], None] -JSONdict = dict[str, JSONitem] +JSONitem = Union[bool, int, float, str, list, typing.Dict[str, Any], None] +JSONdict = typing.Dict[str, JSONitem] Hashableitem = Union[bool, int, float, str, tuple] # or if python version <3.9 diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index c84f15a7..8722546e 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import typing import warnings diff --git a/tests/unit/test_dictmagic.py b/tests/unit/test_dictmagic.py index 1375a029..4407bb54 100644 --- a/tests/unit/test_dictmagic.py +++ b/tests/unit/test_dictmagic.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from muutils.dictmagic import ( From c2a2137e709e93ad05c58cf72c2919e3da661720 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 01:48:40 -0700 Subject: [PATCH 046/158] making typing work for python <3.10 --- .github/workflows/checks.yml | 8 ++++++-- makefile | 14 ++++++++++++-- muutils/json_serialize/serializable_dataclass.py | 2 +- muutils/sysinfo.py | 3 ++- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 4fdd5d14..bef6496e 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -64,6 +64,10 @@ jobs: - name: tests run: make test - - name: lint + - name: check typing if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} - run: make lint + run: make typing + + - name: check typing in compatibility mode + if: ${{ matrix.versions.python == '3.8' || matrix.versions.python == '3.9' }} + run: make typing-compat \ No newline at end of file diff --git a/makefile b/makefile index ee237e95..d0b2a866 100644 --- a/makefile +++ b/makefile @@ -21,6 +21,9 @@ COMMIT_LOG_FILE := .commit_log COMMIT_LOG_SINCE_LAST_VERSION := $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t') # 1 2 3 4 5 +TYPECHECK_COMPAT_ARGS := --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found + + .PHONY: default default: help @@ -77,11 +80,18 @@ cov: # not sure how to fix this # python -m pylint $(PACKAGE_NAME)/ # python -m pylint tests/ -.PHONY: lint -lint: clean +.PHONY: typing +typing: clean + @echo "running type checks" $(PYPOETRY) -m mypy --config-file $(PYPROJECT) $(PACKAGE_NAME)/ $(PYPOETRY) -m mypy --config-file $(PYPROJECT) tests/ +.PHONY: typing-compat +typing-compat: clean + @echo "running type checks in compatibility mode for older python versions" + $(PYPOETRY) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_COMPAT_ARGS) $(PACKAGE_NAME)/ + $(PYPOETRY) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_COMPAT_ARGS) tests/ + .PHONY: test test: clean @echo "running tests" diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 4dd4da30..dfe025bd 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -101,7 +101,7 @@ def from_Field(cls, field: dataclasses.Field) -> "SerializableField": hash=field.hash, compare=field.compare, metadata=field.metadata, - kw_only=field.kw_only, + kw_only=getattr(field, "kw_only", dataclasses.MISSING), # for python <3.9 serialize=field.repr, serialization_fn=None, loading_fn=None, diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index f6835461..c3b8d4e4 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -48,7 +48,8 @@ def python() -> dict: @staticmethod def pip() -> dict: """installed packages info""" - pckgs: list[tuple[str, str]] = [(x.name, x.version) for x in distributions()] + # for some reason, python 3.8 thinks `Distribution` has no attribute `name`? + pckgs: list[tuple[str, str]] = [(x.name, x.version) for x in distributions()] # type: ignore[attr-defined] return { "n_packages": len(pckgs), "packages": pckgs, From 39dc23d3d8341cdf3af13e29e28020623accddf9 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 02:37:53 -0700 Subject: [PATCH 047/158] wipgsgs --- .../json_serialize/serializable_dataclass.py | 18 +- muutils/validate_type.py | 46 ++ test.ipynb | 448 ++++++++++++++++++ tests/unit/test_validate_type.py | 153 ++++++ 4 files changed, 659 insertions(+), 6 deletions(-) create mode 100644 muutils/validate_type.py create mode 100644 test.ipynb create mode 100644 tests/unit/test_validate_type.py diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index dfe025bd..01f4e426 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -403,13 +403,14 @@ def serializable_dataclass( _cls=None, # type: ignore *, init: bool = True, - repr: bool = True, + repr: bool = True, # TODO: this overrides the actual `repr` method, can this be fixed? eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, + on_type_assert: typing.Literal["raise", "warn", "ignore"] = "warn", # TODO: change default to "raise" once more stable **kwargs, ): # -> Union[Callable[[Type[T]], Type[T]], Type[T]]: @@ -565,14 +566,21 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: # assume no loading needs to happen, keep `value` as-is pass + + # store the value in the constructor kwargs + ctor_kwargs[field.name] = value + # validate the type if field.assert_type: if field.name in ctor_kwargs: if field_type_hint is not None: # TODO: recursive type hint checking like pydantic? - assert isinstance( - ctor_kwargs[field.name], field_type_hint - ) + try: + assert _validate_type(ctor_kwargs[field.name], field_type_hint) + except Exception as e: + raise ValueError( + f"{field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {ctor_kwargs[field.name] = }" + ) from e else: raise ValueError( f"Cannot get type hints for {cls.__name__}, and so cannot validate. Python version is {sys.version_info = }. You can:\n" @@ -587,8 +595,6 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: f"Field '{field.name}' on class {cls} has no type hint, but {field.assert_type = }\n{field = }\n{cls_type_hints = }\n{data = }" ) - # store the value in the constructor kwargs - ctor_kwargs[field.name] = value return cls(**ctor_kwargs) diff --git a/muutils/validate_type.py b/muutils/validate_type.py new file mode 100644 index 00000000..c6fa6a16 --- /dev/null +++ b/muutils/validate_type.py @@ -0,0 +1,46 @@ +import types +import typing + +def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: + if expected_type is typing.Any: + return True + + # base type without args + if isinstance(expected_type, type): + return isinstance(value, expected_type) + + origin: type = typing.get_origin(expected_type) + args: list = typing.get_args(expected_type) + + print(f"{origin = } {args = }") + + if origin is types.UnionType: + return any(validate_type(value, arg) for arg in args) + + # generic alias, more complicated + if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias, typing._UnionGenericAlias)): + + if origin is list: + assert len(args) == 1 + return isinstance(value, list) and all(validate_type(item, args[0]) for item in value) + + if origin is dict: + assert len(args) == 2 + return isinstance(value, dict) and all( + validate_type(key, args[0]) and validate_type(val, args[1]) + for key, val in value.items() + ) + + if origin is set: + assert len(args) == 1 + return isinstance(value, set) and all(validate_type(item, args[0]) for item in value) + + if origin is tuple: + if len(value) != len(args): + return False + return all(validate_type(item, arg) for item, arg in zip(value, args)) + + raise ValueError(f"Unsupported generic alias {expected_type}") + + else: + raise ValueError(f"Unsupported type hint {expected_type = } for {value = }") \ No newline at end of file diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 00000000..0041f15a --- /dev/null +++ b/test.ipynb @@ -0,0 +1,448 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import typing" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "x = list[str]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "list[str]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "types.GenericAlias" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "isinstance(x, typing.GenericAlias)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "y = str|int" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "isinstance(y, type)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "type" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(y.__args__[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Iterable, Sequence, Type" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "isinstance(list, type(Iterable))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "list[str, int]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list[str, int]" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "typing.Any" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "typing.Any" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "origin = args = (, )\n", + "origin = args = (, )\n", + "origin = args = (,)\n", + "origin = args = (,)\n", + "origin = args = (,)\n", + "origin = args = (,)\n" + ] + } + ], + "source": [ + "from typing import Any, Union\n", + "from types import UnionType\n", + "import types\n", + "\n", + "def _validate_type(value: Any, expected_type: Any) -> bool:\n", + " if expected_type is Any:\n", + " return True\n", + " \n", + " # base type without args\n", + " if isinstance(expected_type, type):\n", + " return isinstance(value, expected_type)\n", + "\n", + " origin: type = typing.get_origin(expected_type)\n", + " args: list = typing.get_args(expected_type)\n", + " \n", + " print(f\"{origin = } {args = }\")\n", + "\n", + " if origin is types.UnionType:\n", + " return any(_validate_type(value, arg) for arg in args)\n", + "\n", + " # generic alias, more complicated\n", + " if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias)):\n", + "\n", + " if origin is list:\n", + " assert len(args) == 1\n", + " return isinstance(value, list) and all(_validate_type(item, args[0]) for item in value)\n", + " \n", + " if origin is dict:\n", + " assert len(args) == 2\n", + " return isinstance(value, dict) and all(\n", + " _validate_type(key, args[0]) and _validate_type(val, args[1])\n", + " for key, val in value.items()\n", + " )\n", + " \n", + " if origin is set:\n", + " assert len(args) == 1\n", + " return isinstance(value, set) and all(_validate_type(item, args[0]) for item in value)\n", + " \n", + " if origin is tuple:\n", + " if len(value) != len(args):\n", + " return False\n", + " return all(_validate_type(item, arg) for item, arg in zip(value, args))\n", + " \n", + " raise ValueError(f\"Unsupported generic alias {expected_type}\")\n", + "\n", + " else:\n", + " raise ValueError(f\"Unsupported type hint {expected_type = } for {value = }\")\n", + " \n", + "assert _validate_type(1, str|int)\n", + "assert _validate_type(\"a\", str|int)\n", + "assert _validate_type([1, 2, 3], list[int])\n", + "assert not _validate_type([1, 2, 3], list[str])\n", + "assert _validate_type({\"a\", \"b\", \"c\"}, set[str])\n", + "assert not _validate_type({\"a\", \"b\", 1}, set[int])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "str | int" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "z = list[int]" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "z1 = typing.Union[int, str]" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "typing._UnionGenericAlias" + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(z1)" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "typing.Union[int, str]" + ] + }, + "execution_count": 80, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "z1" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "typing.Union" + ] + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "typing.get_origin(z1)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "isinstance(z1, typing._GenericAlias)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "y_o = typing.get_origin(y)\n", + "y_a = typing.get_args(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import types\n", + "types.UnionType\n", + "\n", + "y_o is types.UnionType" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/unit/test_validate_type.py b/tests/unit/test_validate_type.py new file mode 100644 index 00000000..c899cc44 --- /dev/null +++ b/tests/unit/test_validate_type.py @@ -0,0 +1,153 @@ +import pytest + +import typing + +from muutils.validate_type import validate_type +import types +import typing +from typing import Any, Dict, List, Set, Tuple, Union + + +def testvalidate_type_basic_types(): + assert validate_type(42, int) + assert validate_type(3.14, float) + assert validate_type("hello", str) + assert validate_type(True, bool) + assert validate_type([1, 2, 3], list) + assert validate_type({'a': 1, 'b': 2}, dict) + assert validate_type({1, 2, 3}, set) + assert validate_type((1, 2, 3), tuple) + +def testvalidate_type_any(): + assert validate_type(42, Any) + assert validate_type("hello", Any) + assert validate_type([1, 2, 3], Any) + +def testvalidate_type_union(): + assert validate_type(42, Union[int, str]) + assert validate_type("hello", Union[int, str]) + assert validate_type(3.14, Union[int, float]) + assert not validate_type(True, Union[int, str]) + +def testvalidate_type_list(): + assert validate_type([1, 2, 3], List[int]) + assert validate_type(["a", "b", "c"], List[str]) + assert not validate_type([1, "a", 3], List[int]) + assert not validate_type(42, List[int]) + +def testvalidate_type_dict(): + assert validate_type({'a': 1, 'b': 2}, Dict[str, int]) + assert validate_type({1: 'a', 2: 'b'}, Dict[int, str]) + assert not validate_type({'a': 1, 'b': 'c'}, Dict[str, int]) + assert not validate_type([('a', 1), ('b', 2)], Dict[str, int]) + +def testvalidate_type_set(): + assert validate_type({1, 2, 3}, Set[int]) + assert validate_type({"a", "b", "c"}, Set[str]) + assert not validate_type({1, "a", 3}, Set[int]) + assert not validate_type([1, 2, 3], Set[int]) + +def testvalidate_type_tuple(): + assert validate_type((1, 'a', 3.14), Tuple[int, str, float]) + assert validate_type(('a', 'b', 'c'), Tuple[str, str, str]) + assert not validate_type((1, 'a', 3.14), Tuple[int, str]) + assert not validate_type([1, 'a', 3.14], Tuple[int, str, float]) + +def testvalidate_type_unsupported_type_hint(): + with pytest.raises(ValueError, match="Unsupported type hint"): + validate_type(42, typing.Callable[[], None]) + +def testvalidate_type_unsupported_generic_alias(): + with pytest.raises(ValueError, match="Unsupported generic alias"): + validate_type([1, 2, 3], List[int, str]) + +def testvalidate_type_edge_cases(): + assert validate_type([], List[int]) + assert validate_type({}, Dict[str, int]) + assert validate_type(set(), Set[int]) + # assert validate_type((), Tuple[]) + + assert not validate_type(42, List[int]) + assert not validate_type("hello", Dict[str, int]) + assert not validate_type([1, 2], Tuple[int, int, int]) + + assert validate_type([1, 2, [3, 4]], List[Union[int, List[int]]]) + assert validate_type({'a': 1, 'b': {'c': 2}}, Dict[str, Union[int, Dict[str, int]]]) + assert validate_type({1, (2, 3)}, Set[Union[int, Tuple[int, int]]]) + assert validate_type((1, ('a', 'b')), Tuple[int, Tuple[str, str]]) + +def testvalidate_type(): + assert validate_type(5, int) == True + assert validate_type(5.0, int) == False + assert validate_type("hello", str) == True + assert validate_type("hello", typing.Any) == True + assert validate_type([1, 2, 3], typing.List[int]) == True + assert validate_type([1, "2", 3], typing.List[int]) == False + assert validate_type({"key": "value"}, typing.Dict[str, str]) == True + assert validate_type({"key": 2}, typing.Dict[str, str]) == False + assert validate_type((1, "a"), typing.Tuple[int, str]) == True + assert validate_type((1, 2), typing.Tuple[int, str]) == False + assert validate_type((1, 2), typing.Tuple[int, int]) == True + assert validate_type((1, 2, 3), typing.Tuple[int, int]) == False + assert validate_type(5, typing.Union[int, str]) == True + assert validate_type("hello", typing.Union[int, str]) == True + assert validate_type(5.0, typing.Union[int, str]) == False + assert validate_type(None, typing.Union[int, type(None)]) == True + assert validate_type(None, typing.Union[int, type(None), str]) == True + assert validate_type(None, typing.Union[int, str]) == False + assert validate_type({"key": 2}, typing.Dict[str, int]) == True + assert validate_type({"key": 2.0}, typing.Dict[str, int]) == False + assert validate_type([{"key": "value"}], typing.List[typing.Dict[str, str]]) == True + assert validate_type([{"key": 2}], typing.List[typing.Dict[str, str]]) == False + assert validate_type([[1, 2], [3, 4]], typing.List[typing.List[int]]) == True + assert validate_type([[1, 2], [3, "4"]], typing.List[typing.List[int]]) == False + assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) == True + assert validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False + assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) == True + assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) == False + assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) == True + assert validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False + assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) == True + assert validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) == False + assert validate_type(3.14, float) == True + assert validate_type(3.14, int) == False + assert validate_type("3.14", float) == False + assert validate_type(b"bytes", bytes) == True + assert validate_type(b"bytes", str) == False + assert validate_type({"a": 1, "b": 2}, typing.Dict[str, int]) == True + assert validate_type({"a": 1, "b": "2"}, typing.Dict[str, int]) == False + assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) == True + assert validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False + assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) == True + assert validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False + assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) == True + assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) == False + assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) == True + assert validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False + assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) == True + assert validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) == False + assert validate_type(3.14, float) == True + assert validate_type(3.14, int) == False + assert validate_type("3.14", float) == False + assert validate_type(b"bytes", bytes) == True + assert validate_type(b"bytes", str) == False + assert validate_type({"a": 1, "b": 2}, typing.Dict[str, int]) == True + assert validate_type({"a": 1, "b": "2"}, typing.Dict[str, int]) == False + assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) == True + assert validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False + assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) == True + assert validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False + assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) == True + assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) == False + assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) == True + assert validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False + assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) == True + assert validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) == False + assert validate_type(3.14, float) == True + assert validate_type(3.14, int) == False + assert validate_type("3.14", float) == False + assert validate_type(b"bytes", bytes) == True + assert validate_type(b"bytes", str) == False + assert validate_type({"a": 1, "b": 2}, typing.Dict[str, int]) == True + assert validate_type({"a": 1, "b": "2"}, typing.Dict[str, int]) == False + assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) == True From 7874d63d3706fcc4620c6e0954b30eb401cb9aef Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 13:47:34 -0700 Subject: [PATCH 048/158] Wip --- tests/unit/test_validate_type.py | 268 ++++++++++++++++++------------- 1 file changed, 160 insertions(+), 108 deletions(-) diff --git a/tests/unit/test_validate_type.py b/tests/unit/test_validate_type.py index c899cc44..1cfe4e51 100644 --- a/tests/unit/test_validate_type.py +++ b/tests/unit/test_validate_type.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest import typing @@ -7,61 +9,154 @@ import typing from typing import Any, Dict, List, Set, Tuple, Union +# Tests for basic types and common use cases +@pytest.mark.parametrize("value, expected_type, result", [ + (42, int, True), + (3.14, float, True), + (5, int, True), + (5.0, int, False), + ("hello", str, True), + (True, bool, True), + (None, type(None), True), + (None, int, False), + ([1, 2, 3], list, True), + ([1, 2, 3], List, True), + ({'a': 1, 'b': 2}, dict, True), + ({'a': 1, 'b': 2}, Dict, True), + ({1, 2, 3}, set, True), + ({1, 2, 3}, Set, True), + ((1, 2, 3), tuple, True), + ((1, 2, 3), Tuple, True), + (b"bytes", bytes, True), + (b"bytes", str, False), + ("3.14", float, False), + ("hello", Any, True), + (5, Any, True), + (3.14, Any, True), +]) +def test_validate_type_basic(value, expected_type, result): + assert validate_type(value, expected_type) == result -def testvalidate_type_basic_types(): - assert validate_type(42, int) - assert validate_type(3.14, float) - assert validate_type("hello", str) - assert validate_type(True, bool) - assert validate_type([1, 2, 3], list) - assert validate_type({'a': 1, 'b': 2}, dict) - assert validate_type({1, 2, 3}, set) - assert validate_type((1, 2, 3), tuple) - -def testvalidate_type_any(): - assert validate_type(42, Any) - assert validate_type("hello", Any) - assert validate_type([1, 2, 3], Any) - -def testvalidate_type_union(): - assert validate_type(42, Union[int, str]) - assert validate_type("hello", Union[int, str]) - assert validate_type(3.14, Union[int, float]) - assert not validate_type(True, Union[int, str]) - -def testvalidate_type_list(): - assert validate_type([1, 2, 3], List[int]) - assert validate_type(["a", "b", "c"], List[str]) - assert not validate_type([1, "a", 3], List[int]) - assert not validate_type(42, List[int]) -def testvalidate_type_dict(): - assert validate_type({'a': 1, 'b': 2}, Dict[str, int]) - assert validate_type({1: 'a', 2: 'b'}, Dict[int, str]) - assert not validate_type({'a': 1, 'b': 'c'}, Dict[str, int]) - assert not validate_type([('a', 1), ('b', 2)], Dict[str, int]) +@pytest.mark.parametrize("value", [ + 42, + "hello", + 3.14, + True, + None, + [1, 2, 3], + {'a': 1, 'b': 2}, + {1, 2, 3}, + (1, 2, 3), + b"bytes", + "3.14", +]) +def test_validate_type_any(value): + assert validate_type(value, Any) + +@pytest.mark.parametrize("value, expected_type, result", [ + (42, Union[int, str], True), + ("hello", Union[int, str], True), + (3.14, Union[int, float], True), + (True, Union[int, str], False), + (None, Union[int, type(None)], True), + (None, Union[int, str], False), + (5, Union[int, str], True), + (5.0, Union[int, str], True), + ("hello", Union[int, str], True), +]) +def test_validate_type_union(value, expected_type, result): + assert validate_type(value, expected_type) == result + -def testvalidate_type_set(): - assert validate_type({1, 2, 3}, Set[int]) - assert validate_type({"a", "b", "c"}, Set[str]) - assert not validate_type({1, "a", 3}, Set[int]) - assert not validate_type([1, 2, 3], Set[int]) +@pytest.mark.parametrize("value, expected_type, result", [ + (42, List[int], False), + ([1, 2, 3], List[int], True), + ([1, 2, 3], List[str], False), + (["a", "b", "c"], List[str], True), + ([1, "a", 3], List[int], False), + (42, List[int], False), + ([1, 2, 3], List[int], True), + ([1, "2", 3], List[int], False), +]) +def test_validate_type_list(value, expected_type, result): + assert validate_type(value, expected_type) == result -def testvalidate_type_tuple(): + +@pytest.mark.parametrize("value, expected_type, result", [ + (42, Dict[str, int], False), + ({'a': 1, 'b': 2}, Dict[str, int], True), + ({'a': 1, 'b': 2}, Dict[int, str], False), + ({1: 'a', 2: 'b'}, Dict[int, str], True), + ({1: 'a', 2: 'b'}, Dict[str, int], False), + ({'a': 1, 'b': 'c'}, Dict[str, int], False), + ([('a', 1), ('b', 2)], Dict[str, int], False), + ({"key": "value"}, Dict[str, str], True), + ({"key": 2}, Dict[str, str], False), + ({"key": 2}, Dict[str, int], True), + ({"key": 2.0}, Dict[str, int], False), + ({"a": 1, "b": 2}, Dict[str, int], True), + ({"a": 1, "b": "2"}, Dict[str, int], False), +]) +def test_validate_type_dict(value, expected_type, result): + assert validate_type(value, expected_type) == result + +@pytest.mark.parametrize("value, expected_type, result", [ + (42, Set[int], False), + ({1, 2, 3}, Set[int], True), + ({1, 2, 3}, Set[str], False), + ({"a", "b", "c"}, Set[str], True), + ({1, "a", 3}, Set[int], False), + (42, Set[int], False), + ({1, 2, 3}, Set[int], True), + ({1, "2", 3}, Set[int], False), + ([1, 2, 3], Set[int], False), + ("hello", Set[str], False), +]) +def test_validate_type_set(value, expected_type, result): + assert validate_type(value, expected_type) == result + +def test_validate_type_tuple(): + assert validate_type((1, "a"), typing.Tuple[int, str]) + assert validate_type((1, 2), typing.Tuple[int, str]) == False + assert validate_type((1, 2), typing.Tuple[int, int]) + assert validate_type((1, 2, 3), typing.Tuple[int, int]) == False assert validate_type((1, 'a', 3.14), Tuple[int, str, float]) assert validate_type(('a', 'b', 'c'), Tuple[str, str, str]) assert not validate_type((1, 'a', 3.14), Tuple[int, str]) assert not validate_type([1, 'a', 3.14], Tuple[int, str, float]) -def testvalidate_type_unsupported_type_hint(): +def test_validate_type_union(): + assert validate_type(5, typing.Union[int, str]) + assert validate_type("hello", typing.Union[int, str]) + assert validate_type(5.0, typing.Union[int, str]) == False + assert validate_type(None, typing.Union[int, type(None)]) + assert validate_type(None, typing.Union[int, type(None), str]) + assert validate_type(None, typing.Union[int, str]) == False + + +def test_validate_type_unsupported_type_hint(): with pytest.raises(ValueError, match="Unsupported type hint"): validate_type(42, typing.Callable[[], None]) -def testvalidate_type_unsupported_generic_alias(): +def test_validate_type_unsupported_generic_alias(): with pytest.raises(ValueError, match="Unsupported generic alias"): validate_type([1, 2, 3], List[int, str]) -def testvalidate_type_edge_cases(): +@pytest.mark.parametrize("value, expected_type, expected_result", [ + ([1, 2, 3], List[int], True), + (["a", "b", "c"], List[str], True), + ([1, "a", 3], List[int], False), + ([1, 2, [3, 4]], List[Union[int, List[int]]], True), + ([(1, 2), (3, 4)], List[Tuple[int, int]], True), + ([(1, 2), (3, "4")], List[Tuple[int, int]], False), + ({1: [1, 2], 2: [3, 4]}, Dict[int, List[int]], True), + ({1: [1, 2], 2: [3, "4"]}, Dict[int, List[int]], False), +]) +def test_validate_type_collections(value, expected_type, expected_result): + assert validate_type(value, expected_type) == expected_result + +def test_validate_type_edge_cases(): assert validate_type([], List[int]) assert validate_type({}, Dict[str, int]) assert validate_type(set(), Set[int]) @@ -76,78 +171,35 @@ def testvalidate_type_edge_cases(): assert validate_type({1, (2, 3)}, Set[Union[int, Tuple[int, int]]]) assert validate_type((1, ('a', 'b')), Tuple[int, Tuple[str, str]]) -def testvalidate_type(): - assert validate_type(5, int) == True - assert validate_type(5.0, int) == False - assert validate_type("hello", str) == True - assert validate_type("hello", typing.Any) == True - assert validate_type([1, 2, 3], typing.List[int]) == True - assert validate_type([1, "2", 3], typing.List[int]) == False - assert validate_type({"key": "value"}, typing.Dict[str, str]) == True - assert validate_type({"key": 2}, typing.Dict[str, str]) == False - assert validate_type((1, "a"), typing.Tuple[int, str]) == True - assert validate_type((1, 2), typing.Tuple[int, str]) == False - assert validate_type((1, 2), typing.Tuple[int, int]) == True - assert validate_type((1, 2, 3), typing.Tuple[int, int]) == False - assert validate_type(5, typing.Union[int, str]) == True - assert validate_type("hello", typing.Union[int, str]) == True - assert validate_type(5.0, typing.Union[int, str]) == False - assert validate_type(None, typing.Union[int, type(None)]) == True - assert validate_type(None, typing.Union[int, type(None), str]) == True - assert validate_type(None, typing.Union[int, str]) == False - assert validate_type({"key": 2}, typing.Dict[str, int]) == True - assert validate_type({"key": 2.0}, typing.Dict[str, int]) == False - assert validate_type([{"key": "value"}], typing.List[typing.Dict[str, str]]) == True +def test_validate_type_complex(): + assert validate_type([{"key": "value"}], typing.List[typing.Dict[str, str]]) assert validate_type([{"key": 2}], typing.List[typing.Dict[str, str]]) == False - assert validate_type([[1, 2], [3, 4]], typing.List[typing.List[int]]) == True + assert validate_type([[1, 2], [3, 4]], typing.List[typing.List[int]]) assert validate_type([[1, 2], [3, "4"]], typing.List[typing.List[int]]) == False - assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) == True - assert validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False - assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) == True - assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) == False - assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) == True - assert validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False - assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) == True - assert validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) == False - assert validate_type(3.14, float) == True - assert validate_type(3.14, int) == False - assert validate_type("3.14", float) == False - assert validate_type(b"bytes", bytes) == True - assert validate_type(b"bytes", str) == False - assert validate_type({"a": 1, "b": 2}, typing.Dict[str, int]) == True - assert validate_type({"a": 1, "b": "2"}, typing.Dict[str, int]) == False - assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) == True - assert validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False - assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) == True + assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) assert validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False - assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) == True + assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) == False - assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) == True + assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) assert validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False - assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) == True + assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) assert validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) == False - assert validate_type(3.14, float) == True - assert validate_type(3.14, int) == False - assert validate_type("3.14", float) == False - assert validate_type(b"bytes", bytes) == True - assert validate_type(b"bytes", str) == False - assert validate_type({"a": 1, "b": 2}, typing.Dict[str, int]) == True - assert validate_type({"a": 1, "b": "2"}, typing.Dict[str, int]) == False - assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) == True + assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) assert validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False - assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) == True - assert validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False - assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) == True - assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) == False - assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) == True - assert validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False - assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) == True - assert validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) == False - assert validate_type(3.14, float) == True - assert validate_type(3.14, int) == False - assert validate_type("3.14", float) == False - assert validate_type(b"bytes", bytes) == True - assert validate_type(b"bytes", str) == False - assert validate_type({"a": 1, "b": 2}, typing.Dict[str, int]) == True - assert validate_type({"a": 1, "b": "2"}, typing.Dict[str, int]) == False - assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) == True + + +@pytest.mark.parametrize("value, expected_type, result", [ + ([[[[1]]]], List[List[List[List[int]]]], True), + ([[[[1]]]], List[List[List[List[str]]]], False), + ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, int]]], True), + ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, str]]], False), + ({1, 2, 3}, Set[int], True), + ({1, 2, 3}, Set[str], False), + (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, int]], True), + (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, str]], False), +]) +def test_validate_type_nested(value, expected_type, result): + assert validate_type(value, expected_type) == result + + + \ No newline at end of file From dc209fd8a9d11b3b173cde871e35728a5fb43af9 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 14:48:40 -0700 Subject: [PATCH 049/158] wip --- muutils/validate_type.py | 19 +++-- tests/unit/test_validate_type.py | 121 +++++++++++++++++++++++++------ 2 files changed, 109 insertions(+), 31 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index c6fa6a16..d6d9d2ae 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -12,27 +12,28 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: origin: type = typing.get_origin(expected_type) args: list = typing.get_args(expected_type) - print(f"{origin = } {args = }") - - if origin is types.UnionType: + if origin is types.UnionType or origin is typing.Union: return any(validate_type(value, arg) for arg in args) # generic alias, more complicated if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias, typing._UnionGenericAlias)): - + if origin is list: - assert len(args) == 1 + if len(args) != 1: + raise TypeError(f"Too many arguments for list expected 1, got {args = }\n\t{expected_type = }\n\t{value = }") return isinstance(value, list) and all(validate_type(item, args[0]) for item in value) if origin is dict: - assert len(args) == 2 + if len(args) != 2: + raise TypeError(f"Expected 2 arguments for dict, expected 2, got {args = }\n\t{expected_type = }\n\t{value = }") return isinstance(value, dict) and all( validate_type(key, args[0]) and validate_type(val, args[1]) for key, val in value.items() ) if origin is set: - assert len(args) == 1 + if len(args) != 1: + raise TypeError(f"Expected 1 argument for Set, got {args = }\n\t{expected_type = }\n\t{value = }") return isinstance(value, set) and all(validate_type(item, args[0]) for item in value) if origin is tuple: @@ -40,7 +41,9 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return False return all(validate_type(item, arg) for item, arg in zip(value, args)) - raise ValueError(f"Unsupported generic alias {expected_type}") + # TODO: Callables, etc. + + raise ValueError(f"Unsupported generic alias {expected_type} for {value = }, {origin = }, {args = }") else: raise ValueError(f"Unsupported type hint {expected_type = } for {value = }") \ No newline at end of file diff --git a/tests/unit/test_validate_type.py b/tests/unit/test_validate_type.py index 1cfe4e51..c21668ea 100644 --- a/tests/unit/test_validate_type.py +++ b/tests/unit/test_validate_type.py @@ -7,7 +7,7 @@ from muutils.validate_type import validate_type import types import typing -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, Dict, List, Set, Tuple, Union, Optional # Tests for basic types and common use cases @pytest.mark.parametrize("value, expected_type, result", [ @@ -33,6 +33,14 @@ ("hello", Any, True), (5, Any, True), (3.14, Any, True), + # ints + (int(0), int, True), + (int(1), int, True), + (int(-1), int, True), + # bools + (True, bool, True), + (False, bool, True), + ]) def test_validate_type_basic(value, expected_type, result): assert validate_type(value, expected_type) == result @@ -64,10 +72,44 @@ def test_validate_type_any(value): (5, Union[int, str], True), (5.0, Union[int, str], True), ("hello", Union[int, str], True), + (5, typing.Union[int, str], True), + ("hello", typing.Union[int, str], True), + (5.0, typing.Union[int, str], False), + (5, Union[int, str], True), + ("hello", Union[int, str], True), + (5.0, Union[int, str], False), + (5, int|str, True), + ("hello", int|str, True), + (5.0, int|str, False), + (None, typing.Union[int, type(None)], True), + (None, typing.Union[int, str], False), + (None, int|str, False), ]) def test_validate_type_union(value, expected_type, result): assert validate_type(value, expected_type) == result +@pytest.mark.parametrize("value, expected_type, result", [ + (42, Optional[int], True), + ("hello", Optional[int], False), + (3.14, Optional[int], False), + ([1], Optional[List[int]], True), + (None, Optional[int], True), + (None, Optional[str], False), + (None, Optional[int], True), + (None, Optional[None], True), + (None, Optional[list[dict[str, int]]], True), + (42, int|None, True), + ("hello", int|None, False), + (3.14, int|None, False), + ([1], List[int]|None, True), + (None, int|None, True), + (None, str|None, False), + (None, None|str, False), + (None, None|int, True), + (None, None|List[Dict[str, int]], True), +]) +def test_validate_type_optional(value, expected_type, result): + assert validate_type(value, expected_type) == result @pytest.mark.parametrize("value, expected_type, result", [ (42, List[int], False), @@ -83,6 +125,7 @@ def test_validate_type_list(value, expected_type, result): assert validate_type(value, expected_type) == result + @pytest.mark.parametrize("value, expected_type, result", [ (42, Dict[str, int], False), ({'a': 1, 'b': 2}, Dict[str, int], True), @@ -116,32 +159,36 @@ def test_validate_type_dict(value, expected_type, result): def test_validate_type_set(value, expected_type, result): assert validate_type(value, expected_type) == result -def test_validate_type_tuple(): - assert validate_type((1, "a"), typing.Tuple[int, str]) - assert validate_type((1, 2), typing.Tuple[int, str]) == False - assert validate_type((1, 2), typing.Tuple[int, int]) - assert validate_type((1, 2, 3), typing.Tuple[int, int]) == False - assert validate_type((1, 'a', 3.14), Tuple[int, str, float]) - assert validate_type(('a', 'b', 'c'), Tuple[str, str, str]) - assert not validate_type((1, 'a', 3.14), Tuple[int, str]) - assert not validate_type([1, 'a', 3.14], Tuple[int, str, float]) - -def test_validate_type_union(): - assert validate_type(5, typing.Union[int, str]) - assert validate_type("hello", typing.Union[int, str]) - assert validate_type(5.0, typing.Union[int, str]) == False - assert validate_type(None, typing.Union[int, type(None)]) - assert validate_type(None, typing.Union[int, type(None), str]) - assert validate_type(None, typing.Union[int, str]) == False +@pytest.mark.parametrize("value, expected_type, result", [ + (42, Tuple[int, str], False), + ((1, "a"), Tuple[int, str], True), + ((1, 2), Tuple[int, str], False), + ((1, 2), Tuple[int, int], True), + ((1, 2, 3), Tuple[int, int], False), + ((1, 'a', 3.14), Tuple[int, str, float], True), + (('a', 'b', 'c'), Tuple[str, str, str], True), + ((1, 'a', 3.14), Tuple[int, str], False), + ([1, 'a', 3.14], Tuple[int, str, float], False), +]) +def test_validate_type_tuple(value, expected_type, result): + assert validate_type(value, expected_type) == result def test_validate_type_unsupported_type_hint(): - with pytest.raises(ValueError, match="Unsupported type hint"): + with pytest.raises(ValueError): validate_type(42, typing.Callable[[], None]) -def test_validate_type_unsupported_generic_alias(): - with pytest.raises(ValueError, match="Unsupported generic alias"): - validate_type([1, 2, 3], List[int, str]) +@pytest.mark.parametrize("value, expected_type", [ + (42, list[int, str]), + ([1, 2, 3], list[int, str]), + ({"a": 1, "b": 2}, set[str, int, str]), + ({1: "a", 2: "b"}, set[int, str, int]), + ({1, 2, 3}, set[int, str]), + ({"a"}, set[int, str]), +]) +def test_validate_type_unsupported_generic_alias(value, expected_type): + with pytest.raises(TypeError): + validate_type(value, expected_type) @pytest.mark.parametrize("value, expected_type, expected_result", [ ([1, 2, 3], List[int], True), @@ -156,11 +203,39 @@ def test_validate_type_unsupported_generic_alias(): def test_validate_type_collections(value, expected_type, expected_result): assert validate_type(value, expected_type) == expected_result + +@pytest.mark.parametrize("value, expected_type, expected_result", [ + # empty lists + ([], List[int], True), + ([], list[dict], True), + ([], list[tuple[dict[tuple, str], str, None]], True), + # empty dicts + ({}, Dict[str, int], True), + ({}, dict[str, dict], True), + ({}, dict[str, dict[str, int]], True), + ({}, dict[str, dict[str, int]], True), + # empty sets + (set(), Set[int], True), + (set(), set[dict], True), + (set(), set[tuple[dict[tuple, str], str, None]], True), + # empty tuple + (tuple(), tuple, True), + # empty string + ("", str, True), + # empty bytes + (b"", bytes, True), + # None + (None, type(None), True), + # weird floats + (float("nan"), float, True), + (float("inf"), float, True), + (float("-inf"), float, True), + (float(0), float, True), +]) def test_validate_type_edge_cases(): assert validate_type([], List[int]) assert validate_type({}, Dict[str, int]) assert validate_type(set(), Set[int]) - # assert validate_type((), Tuple[]) assert not validate_type(42, List[int]) assert not validate_type("hello", Dict[str, int]) From 5e47f06c79011911d5b6661c52c4d3fac01c44eb Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 15:18:37 -0700 Subject: [PATCH 050/158] fixes! --- muutils/validate_type.py | 24 ++++-- tests/unit/test_validate_type.py | 124 +++++++++++++++++++++++-------- 2 files changed, 112 insertions(+), 36 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index d6d9d2ae..848b13ba 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -17,33 +17,43 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # generic alias, more complicated if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias, typing._UnionGenericAlias)): + + print(f"{value = }, {expected_type = }, {origin = }, {args = }") if origin is list: if len(args) != 1: - raise TypeError(f"Too many arguments for list expected 1, got {args = }\n\t{expected_type = }\n\t{value = }") - return isinstance(value, list) and all(validate_type(item, args[0]) for item in value) + raise TypeError(f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }") + if not isinstance(value, list): + return False + return all(validate_type(item, args[0]) for item in value) if origin is dict: if len(args) != 2: - raise TypeError(f"Expected 2 arguments for dict, expected 2, got {args = }\n\t{expected_type = }\n\t{value = }") - return isinstance(value, dict) and all( + raise TypeError(f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }") + if not isinstance(value, dict): + return False + return all( validate_type(key, args[0]) and validate_type(val, args[1]) for key, val in value.items() ) if origin is set: if len(args) != 1: - raise TypeError(f"Expected 1 argument for Set, got {args = }\n\t{expected_type = }\n\t{value = }") - return isinstance(value, set) and all(validate_type(item, args[0]) for item in value) + raise TypeError(f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }") + if not isinstance(value, set): + return False + return all(validate_type(item, args[0]) for item in value) if origin is tuple: + if not isinstance(value, tuple): + return False if len(value) != len(args): return False return all(validate_type(item, arg) for item, arg in zip(value, args)) # TODO: Callables, etc. - raise ValueError(f"Unsupported generic alias {expected_type} for {value = }, {origin = }, {args = }") + raise ValueError(f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }") else: raise ValueError(f"Unsupported type hint {expected_type = } for {value = }") \ No newline at end of file diff --git a/tests/unit/test_validate_type.py b/tests/unit/test_validate_type.py index c21668ea..98980384 100644 --- a/tests/unit/test_validate_type.py +++ b/tests/unit/test_validate_type.py @@ -43,8 +43,10 @@ ]) def test_validate_type_basic(value, expected_type, result): - assert validate_type(value, expected_type) == result - + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e @pytest.mark.parametrize("value", [ 42, @@ -60,7 +62,10 @@ def test_validate_type_basic(value, expected_type, result): "3.14", ]) def test_validate_type_any(value): - assert validate_type(value, Any) + try: + assert validate_type(value, Any) + except Exception as e: + raise Exception(f"{value = }, expected `Any`, {e}") from e @pytest.mark.parametrize("value, expected_type, result", [ (42, Union[int, str], True), @@ -86,7 +91,10 @@ def test_validate_type_any(value): (None, int|str, False), ]) def test_validate_type_union(value, expected_type, result): - assert validate_type(value, expected_type) == result + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e @pytest.mark.parametrize("value, expected_type, result", [ (42, Optional[int], True), @@ -109,7 +117,10 @@ def test_validate_type_union(value, expected_type, result): (None, None|List[Dict[str, int]], True), ]) def test_validate_type_optional(value, expected_type, result): - assert validate_type(value, expected_type) == result + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e @pytest.mark.parametrize("value, expected_type, result", [ (42, List[int], False), @@ -122,11 +133,17 @@ def test_validate_type_optional(value, expected_type, result): ([1, "2", 3], List[int], False), ]) def test_validate_type_list(value, expected_type, result): - assert validate_type(value, expected_type) == result + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e @pytest.mark.parametrize("value, expected_type, result", [ + (42, dict[str, int], False), + ({'a': 1, 'b': 2}, dict[str, int], True), + ({'a': 1, 'b': 2}, dict[int, str], False), (42, Dict[str, int], False), ({'a': 1, 'b': 2}, Dict[str, int], True), ({'a': 1, 'b': 2}, Dict[int, str], False), @@ -142,9 +159,14 @@ def test_validate_type_list(value, expected_type, result): ({"a": 1, "b": "2"}, Dict[str, int], False), ]) def test_validate_type_dict(value, expected_type, result): - assert validate_type(value, expected_type) == result + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e @pytest.mark.parametrize("value, expected_type, result", [ + (42, set[int], False), + ({1, 2, 3}, set[int], True), (42, Set[int], False), ({1, 2, 3}, Set[int], True), ({1, 2, 3}, Set[str], False), @@ -157,9 +179,14 @@ def test_validate_type_dict(value, expected_type, result): ("hello", Set[str], False), ]) def test_validate_type_set(value, expected_type, result): - assert validate_type(value, expected_type) == result + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e @pytest.mark.parametrize("value, expected_type, result", [ + (42, tuple[int, str], False), + ((1, "a"), tuple[int, str], True), (42, Tuple[int, str], False), ((1, "a"), Tuple[int, str], True), ((1, 2), Tuple[int, str], False), @@ -168,19 +195,30 @@ def test_validate_type_set(value, expected_type, result): ((1, 'a', 3.14), Tuple[int, str, float], True), (('a', 'b', 'c'), Tuple[str, str, str], True), ((1, 'a', 3.14), Tuple[int, str], False), - ([1, 'a', 3.14], Tuple[int, str, float], False), + ([1, 'a', 3.14], Tuple[int, str, float], True), ]) def test_validate_type_tuple(value, expected_type, result): - assert validate_type(value, expected_type) == result - + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e -def test_validate_type_unsupported_type_hint(): +@pytest.mark.parametrize("value, expected_type", [ + (43, typing.Callable), + (lambda x: x, typing.Callable), + (42, typing.Callable[[], None]), + (42, typing.Callable[[int, str], list]), +]) +def test_validate_type_unsupported_type_hint(value, expected_type): with pytest.raises(ValueError): - validate_type(42, typing.Callable[[], None]) + validate_type(value, expected_type) + print(f"Failed to except: {value = }, {expected_type = }") @pytest.mark.parametrize("value, expected_type", [ (42, list[int, str]), ([1, 2, 3], list[int, str]), + ({"a": 1, "b": 2}, set[str, int]), + ({1: "a", 2: "b"}, set[int, str]), ({"a": 1, "b": 2}, set[str, int, str]), ({1: "a", 2: "b"}, set[int, str, int]), ({1, 2, 3}, set[int, str]), @@ -189,8 +227,9 @@ def test_validate_type_unsupported_type_hint(): def test_validate_type_unsupported_generic_alias(value, expected_type): with pytest.raises(TypeError): validate_type(value, expected_type) + print(f"Failed to except: {value = }, {expected_type = }") -@pytest.mark.parametrize("value, expected_type, expected_result", [ +@pytest.mark.parametrize("value, expected_type, result", [ ([1, 2, 3], List[int], True), (["a", "b", "c"], List[str], True), ([1, "a", 3], List[int], False), @@ -200,11 +239,14 @@ def test_validate_type_unsupported_generic_alias(value, expected_type): ({1: [1, 2], 2: [3, 4]}, Dict[int, List[int]], True), ({1: [1, 2], 2: [3, "4"]}, Dict[int, List[int]], False), ]) -def test_validate_type_collections(value, expected_type, expected_result): - assert validate_type(value, expected_type) == expected_result +def test_validate_type_collections(value, expected_type, result): + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e -@pytest.mark.parametrize("value, expected_type, expected_result", [ +@pytest.mark.parametrize("value, expected_type, result", [ # empty lists ([], List[int], True), ([], list[dict], True), @@ -231,22 +273,43 @@ def test_validate_type_collections(value, expected_type, expected_result): (float("inf"), float, True), (float("-inf"), float, True), (float(0), float, True), + # list/tuple + ([1], tuple[int, int], False), + ((1,2), list[int], False), ]) -def test_validate_type_edge_cases(): - assert validate_type([], List[int]) - assert validate_type({}, Dict[str, int]) - assert validate_type(set(), Set[int]) - - assert not validate_type(42, List[int]) - assert not validate_type("hello", Dict[str, int]) - assert not validate_type([1, 2], Tuple[int, int, int]) - +def test_validate_type_edge_cases(value, expected_type, result): + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + + +@pytest.mark.parametrize("value, expected_type, result", [ + (42, list[int], False), + ([1, 2, 3], int, False), + (3.14, tuple[float], False), + (3.14, tuple[float, float], False), + (3.14, tuple[bool, str], False), + (False, tuple[bool, str], False), + (False, tuple[bool], False), + ((False,), tuple[bool], True), + (("abc",), tuple[str], True), + ("test-dict", dict[str, int], False), + ("test-dict", dict, False), +]) +def test_validate_type_wrong_type(value, expected_type, result): + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + + + +def test_validate_type_complex(): assert validate_type([1, 2, [3, 4]], List[Union[int, List[int]]]) assert validate_type({'a': 1, 'b': {'c': 2}}, Dict[str, Union[int, Dict[str, int]]]) assert validate_type({1, (2, 3)}, Set[Union[int, Tuple[int, int]]]) assert validate_type((1, ('a', 'b')), Tuple[int, Tuple[str, str]]) - -def test_validate_type_complex(): assert validate_type([{"key": "value"}], typing.List[typing.Dict[str, str]]) assert validate_type([{"key": 2}], typing.List[typing.Dict[str, str]]) == False assert validate_type([[1, 2], [3, 4]], typing.List[typing.List[int]]) @@ -274,7 +337,10 @@ def test_validate_type_complex(): (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, str]], False), ]) def test_validate_type_nested(value, expected_type, result): - assert validate_type(value, expected_type) == result + try: + assert validate_type(value, expected_type) == result + except Exception as e: + raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e \ No newline at end of file From 85bdb1e7a25c4d2377fa9f4b85a7943f2315b6b3 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 15:24:14 -0700 Subject: [PATCH 051/158] more fixes --- muutils/validate_type.py | 4 +- tests/unit/test_validate_type.py | 97 ++++++++++++++++---------------- 2 files changed, 53 insertions(+), 48 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index 848b13ba..b81f8984 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -12,13 +12,15 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: origin: type = typing.get_origin(expected_type) args: list = typing.get_args(expected_type) + # useful for debugging + print(f"{value = }, {expected_type = }, {origin = }, {args = }") + if origin is types.UnionType or origin is typing.Union: return any(validate_type(value, arg) for arg in args) # generic alias, more complicated if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias, typing._UnionGenericAlias)): - print(f"{value = }, {expected_type = }, {origin = }, {args = }") if origin is list: if len(args) != 1: diff --git a/tests/unit/test_validate_type.py b/tests/unit/test_validate_type.py index 98980384..62007680 100644 --- a/tests/unit/test_validate_type.py +++ b/tests/unit/test_validate_type.py @@ -10,7 +10,7 @@ from typing import Any, Dict, List, Set, Tuple, Union, Optional # Tests for basic types and common use cases -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ (42, int, True), (3.14, float, True), (5, int, True), @@ -42,11 +42,11 @@ (False, bool, True), ]) -def test_validate_type_basic(value, expected_type, result): +def test_validate_type_basic(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e @pytest.mark.parametrize("value", [ 42, @@ -67,7 +67,7 @@ def test_validate_type_any(value): except Exception as e: raise Exception(f"{value = }, expected `Any`, {e}") from e -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ (42, Union[int, str], True), ("hello", Union[int, str], True), (3.14, Union[int, float], True), @@ -90,13 +90,13 @@ def test_validate_type_any(value): (None, typing.Union[int, str], False), (None, int|str, False), ]) -def test_validate_type_union(value, expected_type, result): +def test_validate_type_union(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ (42, Optional[int], True), ("hello", Optional[int], False), (3.14, Optional[int], False), @@ -111,18 +111,19 @@ def test_validate_type_union(value, expected_type, result): (3.14, int|None, False), ([1], List[int]|None, True), (None, int|None, True), - (None, str|None, False), - (None, None|str, False), + (None, str|None, True), + (None, None|str, True), (None, None|int, True), + (None, str|int, False), (None, None|List[Dict[str, int]], True), ]) -def test_validate_type_optional(value, expected_type, result): +def test_validate_type_optional(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ (42, List[int], False), ([1, 2, 3], List[int], True), ([1, 2, 3], List[str], False), @@ -132,15 +133,15 @@ def test_validate_type_optional(value, expected_type, result): ([1, 2, 3], List[int], True), ([1, "2", 3], List[int], False), ]) -def test_validate_type_list(value, expected_type, result): +def test_validate_type_list(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ (42, dict[str, int], False), ({'a': 1, 'b': 2}, dict[str, int], True), ({'a': 1, 'b': 2}, dict[int, str], False), @@ -158,13 +159,13 @@ def test_validate_type_list(value, expected_type, result): ({"a": 1, "b": 2}, Dict[str, int], True), ({"a": 1, "b": "2"}, Dict[str, int], False), ]) -def test_validate_type_dict(value, expected_type, result): +def test_validate_type_dict(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ (42, set[int], False), ({1, 2, 3}, set[int], True), (42, Set[int], False), @@ -178,13 +179,13 @@ def test_validate_type_dict(value, expected_type, result): ([1, 2, 3], Set[int], False), ("hello", Set[str], False), ]) -def test_validate_type_set(value, expected_type, result): +def test_validate_type_set(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ (42, tuple[int, str], False), ((1, "a"), tuple[int, str], True), (42, Tuple[int, str], False), @@ -195,13 +196,15 @@ def test_validate_type_set(value, expected_type, result): ((1, 'a', 3.14), Tuple[int, str, float], True), (('a', 'b', 'c'), Tuple[str, str, str], True), ((1, 'a', 3.14), Tuple[int, str], False), - ([1, 'a', 3.14], Tuple[int, str, float], True), + ((1, 'a', 3.14), Tuple[int, str, float], True), + ([1, 'a', 3.14], Tuple[int, str, float], False), + ((1, 'a', 3.14, "b", True, None, (1, 2, 3)), Tuple[int, str, float, str, bool, type(None), Tuple[int, int, int]], True), ]) -def test_validate_type_tuple(value, expected_type, result): +def test_validate_type_tuple(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e @pytest.mark.parametrize("value, expected_type", [ (43, typing.Callable), @@ -229,7 +232,7 @@ def test_validate_type_unsupported_generic_alias(value, expected_type): validate_type(value, expected_type) print(f"Failed to except: {value = }, {expected_type = }") -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ ([1, 2, 3], List[int], True), (["a", "b", "c"], List[str], True), ([1, "a", 3], List[int], False), @@ -239,14 +242,14 @@ def test_validate_type_unsupported_generic_alias(value, expected_type): ({1: [1, 2], 2: [3, 4]}, Dict[int, List[int]], True), ({1: [1, 2], 2: [3, "4"]}, Dict[int, List[int]], False), ]) -def test_validate_type_collections(value, expected_type, result): +def test_validate_type_collections(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ # empty lists ([], List[int], True), ([], list[dict], True), @@ -277,14 +280,14 @@ def test_validate_type_collections(value, expected_type, result): ([1], tuple[int, int], False), ((1,2), list[int], False), ]) -def test_validate_type_edge_cases(value, expected_type, result): +def test_validate_type_edge_cases(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ (42, list[int], False), ([1, 2, 3], int, False), (3.14, tuple[float], False), @@ -297,11 +300,11 @@ def test_validate_type_edge_cases(value, expected_type, result): ("test-dict", dict[str, int], False), ("test-dict", dict, False), ]) -def test_validate_type_wrong_type(value, expected_type, result): +def test_validate_type_wrong_type(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e @@ -326,7 +329,7 @@ def test_validate_type_complex(): assert validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False -@pytest.mark.parametrize("value, expected_type, result", [ +@pytest.mark.parametrize("value, expected_type, expected_result", [ ([[[[1]]]], List[List[List[List[int]]]], True), ([[[[1]]]], List[List[List[List[str]]]], False), ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, int]]], True), @@ -336,11 +339,11 @@ def test_validate_type_complex(): (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, int]], True), (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, str]], False), ]) -def test_validate_type_nested(value, expected_type, result): +def test_validate_type_nested(value, expected_type, expected_result): try: - assert validate_type(value, expected_type) == result + assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {result = }, {e}") from e + raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e \ No newline at end of file From aa992fa8e2c194d12dca630cd52d0dfb1c5b0786 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 15:35:41 -0700 Subject: [PATCH 052/158] test_validate_type tests pass! --- muutils/validate_type.py | 36 ++++++++++++++++++++++++++++---- tests/unit/test_validate_type.py | 17 ++++++++++++--- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index b81f8984..84fa4c0e 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -19,38 +19,66 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return any(validate_type(value, arg) for arg in args) # generic alias, more complicated - if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias, typing._UnionGenericAlias)): + if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias, typing._UnionGenericAlias, typing._BaseGenericAlias)): if origin is list: + # no args + if len(args) == 0: + return isinstance(value, list) + # incorrect number of args if len(args) != 1: raise TypeError(f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }") + # check is list if not isinstance(value, list): return False - return all(validate_type(item, args[0]) for item in value) + # check all items in list are of the correct type + item_type: type = args[0] + return all(validate_type(item, item_type) for item in value) if origin is dict: + # no args + if len(args) == 0: + return isinstance(value, dict) + # incorrect number of args if len(args) != 2: raise TypeError(f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }") + # check is dict if not isinstance(value, dict): return False + # check all items in dict are of the correct type + key_type: type = args[0] + value_type: type = args[1] return all( - validate_type(key, args[0]) and validate_type(val, args[1]) + validate_type(key, key_type) and validate_type(val, value_type) for key, val in value.items() ) if origin is set: + # no args + if len(args) == 0: + return isinstance(value, set) + # incorrect number of args if len(args) != 1: raise TypeError(f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }") + # check is set if not isinstance(value, set): return False - return all(validate_type(item, args[0]) for item in value) + # check all items in set are of the correct type + item_type: type = args[0] + return all(validate_type(item, item_type) for item in value) if origin is tuple: + # no args + if len(args) == 0: + return isinstance(value, tuple) + # check is tuple if not isinstance(value, tuple): return False + # check correct number of items in tuple if len(value) != len(args): return False + # check all items in tuple are of the correct type return all(validate_type(item, arg) for item, arg in zip(value, args)) # TODO: Callables, etc. diff --git a/tests/unit/test_validate_type.py b/tests/unit/test_validate_type.py index 62007680..6ec42826 100644 --- a/tests/unit/test_validate_type.py +++ b/tests/unit/test_validate_type.py @@ -71,11 +71,11 @@ def test_validate_type_any(value): (42, Union[int, str], True), ("hello", Union[int, str], True), (3.14, Union[int, float], True), - (True, Union[int, str], False), + (True, Union[int, str], True), (None, Union[int, type(None)], True), (None, Union[int, str], False), (5, Union[int, str], True), - (5.0, Union[int, str], True), + (5.0, Union[int, str], False), ("hello", Union[int, str], True), (5, typing.Union[int, str], True), ("hello", typing.Union[int, str], True), @@ -102,7 +102,7 @@ def test_validate_type_union(value, expected_type, expected_result): (3.14, Optional[int], False), ([1], Optional[List[int]], True), (None, Optional[int], True), - (None, Optional[str], False), + (None, Optional[str], True), (None, Optional[int], True), (None, Optional[None], True), (None, Optional[list[dict[str, int]]], True), @@ -271,6 +271,17 @@ def test_validate_type_collections(value, expected_type, expected_result): (b"", bytes, True), # None (None, type(None), True), + # bools are ints, ints are not floats + (True, int, True), + (False, int, True), + (True, float, False), + (False, float, False), + (1, int, True), + (0, int, True), + (1, float, False), + (0, float, False), + (0, bool, False), + (1, bool, False), # weird floats (float("nan"), float, True), (float("inf"), float, True), From c27b86d96457cab805a9dc95e9784818e2895749 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 15:36:06 -0700 Subject: [PATCH 053/158] format --- .../json_serialize/serializable_dataclass.py | 12 +- muutils/validate_type.py | 44 +- tests/unit/test_validate_type.py | 636 ++++++++++-------- 3 files changed, 395 insertions(+), 297 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 01f4e426..8fb10944 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -403,14 +403,16 @@ def serializable_dataclass( _cls=None, # type: ignore *, init: bool = True, - repr: bool = True, # TODO: this overrides the actual `repr` method, can this be fixed? + repr: bool = True, # TODO: this overrides the actual `repr` method, can this be fixed? eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, - on_type_assert: typing.Literal["raise", "warn", "ignore"] = "warn", # TODO: change default to "raise" once more stable + on_type_assert: typing.Literal[ + "raise", "warn", "ignore" + ] = "warn", # TODO: change default to "raise" once more stable **kwargs, ): # -> Union[Callable[[Type[T]], Type[T]], Type[T]]: @@ -566,7 +568,6 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: # assume no loading needs to happen, keep `value` as-is pass - # store the value in the constructor kwargs ctor_kwargs[field.name] = value @@ -576,7 +577,9 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: if field_type_hint is not None: # TODO: recursive type hint checking like pydantic? try: - assert _validate_type(ctor_kwargs[field.name], field_type_hint) + assert _validate_type( + ctor_kwargs[field.name], field_type_hint + ) except Exception as e: raise ValueError( f"{field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {ctor_kwargs[field.name] = }" @@ -595,7 +598,6 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: f"Field '{field.name}' on class {cls} has no type hint, but {field.assert_type = }\n{field = }\n{cls_type_hints = }\n{data = }" ) - return cls(**ctor_kwargs) # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments diff --git a/muutils/validate_type.py b/muutils/validate_type.py index 84fa4c0e..b4688c63 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -1,17 +1,18 @@ import types import typing + def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: if expected_type is typing.Any: return True - + # base type without args if isinstance(expected_type, type): return isinstance(value, expected_type) origin: type = typing.get_origin(expected_type) args: list = typing.get_args(expected_type) - + # useful for debugging print(f"{value = }, {expected_type = }, {origin = }, {args = }") @@ -19,30 +20,41 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return any(validate_type(value, arg) for arg in args) # generic alias, more complicated - if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias, typing._UnionGenericAlias, typing._BaseGenericAlias)): + if isinstance( + expected_type, + ( + typing.GenericAlias, + typing._GenericAlias, + typing._UnionGenericAlias, + typing._BaseGenericAlias, + ), + ): - if origin is list: # no args if len(args) == 0: return isinstance(value, list) # incorrect number of args if len(args) != 1: - raise TypeError(f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }") + raise TypeError( + f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }" + ) # check is list if not isinstance(value, list): return False # check all items in list are of the correct type item_type: type = args[0] return all(validate_type(item, item_type) for item in value) - + if origin is dict: # no args if len(args) == 0: return isinstance(value, dict) # incorrect number of args if len(args) != 2: - raise TypeError(f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }") + raise TypeError( + f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }" + ) # check is dict if not isinstance(value, dict): return False @@ -53,21 +65,23 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: validate_type(key, key_type) and validate_type(val, value_type) for key, val in value.items() ) - + if origin is set: # no args if len(args) == 0: return isinstance(value, set) # incorrect number of args if len(args) != 1: - raise TypeError(f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }") + raise TypeError( + f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }" + ) # check is set if not isinstance(value, set): return False # check all items in set are of the correct type item_type: type = args[0] return all(validate_type(item, item_type) for item in value) - + if origin is tuple: # no args if len(args) == 0: @@ -80,10 +94,12 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return False # check all items in tuple are of the correct type return all(validate_type(item, arg) for item, arg in zip(value, args)) - + # TODO: Callables, etc. - - raise ValueError(f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }") + + raise ValueError( + f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }" + ) else: - raise ValueError(f"Unsupported type hint {expected_type = } for {value = }") \ No newline at end of file + raise ValueError(f"Unsupported type hint {expected_type = } for {value = }") diff --git a/tests/unit/test_validate_type.py b/tests/unit/test_validate_type.py index 6ec42826..6ed340ec 100644 --- a/tests/unit/test_validate_type.py +++ b/tests/unit/test_validate_type.py @@ -1,360 +1,440 @@ from __future__ import annotations -import pytest - import typing +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import pytest from muutils.validate_type import validate_type -import types -import typing -from typing import Any, Dict, List, Set, Tuple, Union, Optional + # Tests for basic types and common use cases -@pytest.mark.parametrize("value, expected_type, expected_result", [ - (42, int, True), - (3.14, float, True), - (5, int, True), - (5.0, int, False), - ("hello", str, True), - (True, bool, True), - (None, type(None), True), - (None, int, False), - ([1, 2, 3], list, True), - ([1, 2, 3], List, True), - ({'a': 1, 'b': 2}, dict, True), - ({'a': 1, 'b': 2}, Dict, True), - ({1, 2, 3}, set, True), - ({1, 2, 3}, Set, True), - ((1, 2, 3), tuple, True), - ((1, 2, 3), Tuple, True), - (b"bytes", bytes, True), - (b"bytes", str, False), - ("3.14", float, False), - ("hello", Any, True), - (5, Any, True), - (3.14, Any, True), - # ints - (int(0), int, True), - (int(1), int, True), - (int(-1), int, True), - # bools - (True, bool, True), - (False, bool, True), - -]) +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (42, int, True), + (3.14, float, True), + (5, int, True), + (5.0, int, False), + ("hello", str, True), + (True, bool, True), + (None, type(None), True), + (None, int, False), + ([1, 2, 3], list, True), + ([1, 2, 3], List, True), + ({"a": 1, "b": 2}, dict, True), + ({"a": 1, "b": 2}, Dict, True), + ({1, 2, 3}, set, True), + ({1, 2, 3}, Set, True), + ((1, 2, 3), tuple, True), + ((1, 2, 3), Tuple, True), + (b"bytes", bytes, True), + (b"bytes", str, False), + ("3.14", float, False), + ("hello", Any, True), + (5, Any, True), + (3.14, Any, True), + # ints + (int(0), int, True), + (int(1), int, True), + (int(-1), int, True), + # bools + (True, bool, True), + (False, bool, True), + ], +) def test_validate_type_basic(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - -@pytest.mark.parametrize("value", [ - 42, - "hello", - 3.14, - True, - None, - [1, 2, 3], - {'a': 1, 'b': 2}, - {1, 2, 3}, - (1, 2, 3), - b"bytes", - "3.14", -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value", + [ + 42, + "hello", + 3.14, + True, + None, + [1, 2, 3], + {"a": 1, "b": 2}, + {1, 2, 3}, + (1, 2, 3), + b"bytes", + "3.14", + ], +) def test_validate_type_any(value): try: assert validate_type(value, Any) except Exception as e: raise Exception(f"{value = }, expected `Any`, {e}") from e -@pytest.mark.parametrize("value, expected_type, expected_result", [ - (42, Union[int, str], True), - ("hello", Union[int, str], True), - (3.14, Union[int, float], True), - (True, Union[int, str], True), - (None, Union[int, type(None)], True), - (None, Union[int, str], False), - (5, Union[int, str], True), - (5.0, Union[int, str], False), - ("hello", Union[int, str], True), - (5, typing.Union[int, str], True), - ("hello", typing.Union[int, str], True), - (5.0, typing.Union[int, str], False), - (5, Union[int, str], True), - ("hello", Union[int, str], True), - (5.0, Union[int, str], False), - (5, int|str, True), - ("hello", int|str, True), - (5.0, int|str, False), - (None, typing.Union[int, type(None)], True), - (None, typing.Union[int, str], False), - (None, int|str, False), -]) + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (42, Union[int, str], True), + ("hello", Union[int, str], True), + (3.14, Union[int, float], True), + (True, Union[int, str], True), + (None, Union[int, type(None)], True), + (None, Union[int, str], False), + (5, Union[int, str], True), + (5.0, Union[int, str], False), + ("hello", Union[int, str], True), + (5, typing.Union[int, str], True), + ("hello", typing.Union[int, str], True), + (5.0, typing.Union[int, str], False), + (5, Union[int, str], True), + ("hello", Union[int, str], True), + (5.0, Union[int, str], False), + (5, int | str, True), + ("hello", int | str, True), + (5.0, int | str, False), + (None, typing.Union[int, type(None)], True), + (None, typing.Union[int, str], False), + (None, int | str, False), + ], +) def test_validate_type_union(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - -@pytest.mark.parametrize("value, expected_type, expected_result", [ - (42, Optional[int], True), - ("hello", Optional[int], False), - (3.14, Optional[int], False), - ([1], Optional[List[int]], True), - (None, Optional[int], True), - (None, Optional[str], True), - (None, Optional[int], True), - (None, Optional[None], True), - (None, Optional[list[dict[str, int]]], True), - (42, int|None, True), - ("hello", int|None, False), - (3.14, int|None, False), - ([1], List[int]|None, True), - (None, int|None, True), - (None, str|None, True), - (None, None|str, True), - (None, None|int, True), - (None, str|int, False), - (None, None|List[Dict[str, int]], True), -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (42, Optional[int], True), + ("hello", Optional[int], False), + (3.14, Optional[int], False), + ([1], Optional[List[int]], True), + (None, Optional[int], True), + (None, Optional[str], True), + (None, Optional[int], True), + (None, Optional[None], True), + (None, Optional[list[dict[str, int]]], True), + (42, int | None, True), + ("hello", int | None, False), + (3.14, int | None, False), + ([1], List[int] | None, True), + (None, int | None, True), + (None, str | None, True), + (None, None | str, True), + (None, None | int, True), + (None, str | int, False), + (None, None | List[Dict[str, int]], True), + ], +) def test_validate_type_optional(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - -@pytest.mark.parametrize("value, expected_type, expected_result", [ - (42, List[int], False), - ([1, 2, 3], List[int], True), - ([1, 2, 3], List[str], False), - (["a", "b", "c"], List[str], True), - ([1, "a", 3], List[int], False), - (42, List[int], False), - ([1, 2, 3], List[int], True), - ([1, "2", 3], List[int], False), -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (42, List[int], False), + ([1, 2, 3], List[int], True), + ([1, 2, 3], List[str], False), + (["a", "b", "c"], List[str], True), + ([1, "a", 3], List[int], False), + (42, List[int], False), + ([1, 2, 3], List[int], True), + ([1, "2", 3], List[int], False), + ], +) def test_validate_type_list(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - - - -@pytest.mark.parametrize("value, expected_type, expected_result", [ - (42, dict[str, int], False), - ({'a': 1, 'b': 2}, dict[str, int], True), - ({'a': 1, 'b': 2}, dict[int, str], False), - (42, Dict[str, int], False), - ({'a': 1, 'b': 2}, Dict[str, int], True), - ({'a': 1, 'b': 2}, Dict[int, str], False), - ({1: 'a', 2: 'b'}, Dict[int, str], True), - ({1: 'a', 2: 'b'}, Dict[str, int], False), - ({'a': 1, 'b': 'c'}, Dict[str, int], False), - ([('a', 1), ('b', 2)], Dict[str, int], False), - ({"key": "value"}, Dict[str, str], True), - ({"key": 2}, Dict[str, str], False), - ({"key": 2}, Dict[str, int], True), - ({"key": 2.0}, Dict[str, int], False), - ({"a": 1, "b": 2}, Dict[str, int], True), - ({"a": 1, "b": "2"}, Dict[str, int], False), -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (42, dict[str, int], False), + ({"a": 1, "b": 2}, dict[str, int], True), + ({"a": 1, "b": 2}, dict[int, str], False), + (42, Dict[str, int], False), + ({"a": 1, "b": 2}, Dict[str, int], True), + ({"a": 1, "b": 2}, Dict[int, str], False), + ({1: "a", 2: "b"}, Dict[int, str], True), + ({1: "a", 2: "b"}, Dict[str, int], False), + ({"a": 1, "b": "c"}, Dict[str, int], False), + ([("a", 1), ("b", 2)], Dict[str, int], False), + ({"key": "value"}, Dict[str, str], True), + ({"key": 2}, Dict[str, str], False), + ({"key": 2}, Dict[str, int], True), + ({"key": 2.0}, Dict[str, int], False), + ({"a": 1, "b": 2}, Dict[str, int], True), + ({"a": 1, "b": "2"}, Dict[str, int], False), + ], +) def test_validate_type_dict(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - -@pytest.mark.parametrize("value, expected_type, expected_result", [ - (42, set[int], False), - ({1, 2, 3}, set[int], True), - (42, Set[int], False), - ({1, 2, 3}, Set[int], True), - ({1, 2, 3}, Set[str], False), - ({"a", "b", "c"}, Set[str], True), - ({1, "a", 3}, Set[int], False), - (42, Set[int], False), - ({1, 2, 3}, Set[int], True), - ({1, "2", 3}, Set[int], False), - ([1, 2, 3], Set[int], False), - ("hello", Set[str], False), -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (42, set[int], False), + ({1, 2, 3}, set[int], True), + (42, Set[int], False), + ({1, 2, 3}, Set[int], True), + ({1, 2, 3}, Set[str], False), + ({"a", "b", "c"}, Set[str], True), + ({1, "a", 3}, Set[int], False), + (42, Set[int], False), + ({1, 2, 3}, Set[int], True), + ({1, "2", 3}, Set[int], False), + ([1, 2, 3], Set[int], False), + ("hello", Set[str], False), + ], +) def test_validate_type_set(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - -@pytest.mark.parametrize("value, expected_type, expected_result", [ - (42, tuple[int, str], False), - ((1, "a"), tuple[int, str], True), - (42, Tuple[int, str], False), - ((1, "a"), Tuple[int, str], True), - ((1, 2), Tuple[int, str], False), - ((1, 2), Tuple[int, int], True), - ((1, 2, 3), Tuple[int, int], False), - ((1, 'a', 3.14), Tuple[int, str, float], True), - (('a', 'b', 'c'), Tuple[str, str, str], True), - ((1, 'a', 3.14), Tuple[int, str], False), - ((1, 'a', 3.14), Tuple[int, str, float], True), - ([1, 'a', 3.14], Tuple[int, str, float], False), - ((1, 'a', 3.14, "b", True, None, (1, 2, 3)), Tuple[int, str, float, str, bool, type(None), Tuple[int, int, int]], True), -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (42, tuple[int, str], False), + ((1, "a"), tuple[int, str], True), + (42, Tuple[int, str], False), + ((1, "a"), Tuple[int, str], True), + ((1, 2), Tuple[int, str], False), + ((1, 2), Tuple[int, int], True), + ((1, 2, 3), Tuple[int, int], False), + ((1, "a", 3.14), Tuple[int, str, float], True), + (("a", "b", "c"), Tuple[str, str, str], True), + ((1, "a", 3.14), Tuple[int, str], False), + ((1, "a", 3.14), Tuple[int, str, float], True), + ([1, "a", 3.14], Tuple[int, str, float], False), + ( + (1, "a", 3.14, "b", True, None, (1, 2, 3)), + Tuple[int, str, float, str, bool, type(None), Tuple[int, int, int]], + True, + ), + ], +) def test_validate_type_tuple(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - -@pytest.mark.parametrize("value, expected_type", [ - (43, typing.Callable), - (lambda x: x, typing.Callable), - (42, typing.Callable[[], None]), - (42, typing.Callable[[int, str], list]), -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value, expected_type", + [ + (43, typing.Callable), + (lambda x: x, typing.Callable), + (42, typing.Callable[[], None]), + (42, typing.Callable[[int, str], list]), + ], +) def test_validate_type_unsupported_type_hint(value, expected_type): with pytest.raises(ValueError): validate_type(value, expected_type) print(f"Failed to except: {value = }, {expected_type = }") -@pytest.mark.parametrize("value, expected_type", [ - (42, list[int, str]), - ([1, 2, 3], list[int, str]), - ({"a": 1, "b": 2}, set[str, int]), - ({1: "a", 2: "b"}, set[int, str]), - ({"a": 1, "b": 2}, set[str, int, str]), - ({1: "a", 2: "b"}, set[int, str, int]), - ({1, 2, 3}, set[int, str]), - ({"a"}, set[int, str]), -]) + +@pytest.mark.parametrize( + "value, expected_type", + [ + (42, list[int, str]), + ([1, 2, 3], list[int, str]), + ({"a": 1, "b": 2}, set[str, int]), + ({1: "a", 2: "b"}, set[int, str]), + ({"a": 1, "b": 2}, set[str, int, str]), + ({1: "a", 2: "b"}, set[int, str, int]), + ({1, 2, 3}, set[int, str]), + ({"a"}, set[int, str]), + ], +) def test_validate_type_unsupported_generic_alias(value, expected_type): with pytest.raises(TypeError): validate_type(value, expected_type) print(f"Failed to except: {value = }, {expected_type = }") -@pytest.mark.parametrize("value, expected_type, expected_result", [ - ([1, 2, 3], List[int], True), - (["a", "b", "c"], List[str], True), - ([1, "a", 3], List[int], False), - ([1, 2, [3, 4]], List[Union[int, List[int]]], True), - ([(1, 2), (3, 4)], List[Tuple[int, int]], True), - ([(1, 2), (3, "4")], List[Tuple[int, int]], False), - ({1: [1, 2], 2: [3, 4]}, Dict[int, List[int]], True), - ({1: [1, 2], 2: [3, "4"]}, Dict[int, List[int]], False), -]) + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + ([1, 2, 3], List[int], True), + (["a", "b", "c"], List[str], True), + ([1, "a", 3], List[int], False), + ([1, 2, [3, 4]], List[Union[int, List[int]]], True), + ([(1, 2), (3, 4)], List[Tuple[int, int]], True), + ([(1, 2), (3, "4")], List[Tuple[int, int]], False), + ({1: [1, 2], 2: [3, 4]}, Dict[int, List[int]], True), + ({1: [1, 2], 2: [3, "4"]}, Dict[int, List[int]], False), + ], +) def test_validate_type_collections(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - - -@pytest.mark.parametrize("value, expected_type, expected_result", [ - # empty lists - ([], List[int], True), - ([], list[dict], True), - ([], list[tuple[dict[tuple, str], str, None]], True), - # empty dicts - ({}, Dict[str, int], True), - ({}, dict[str, dict], True), - ({}, dict[str, dict[str, int]], True), - ({}, dict[str, dict[str, int]], True), - # empty sets - (set(), Set[int], True), - (set(), set[dict], True), - (set(), set[tuple[dict[tuple, str], str, None]], True), - # empty tuple - (tuple(), tuple, True), - # empty string - ("", str, True), - # empty bytes - (b"", bytes, True), - # None - (None, type(None), True), - # bools are ints, ints are not floats - (True, int, True), - (False, int, True), - (True, float, False), - (False, float, False), - (1, int, True), - (0, int, True), - (1, float, False), - (0, float, False), - (0, bool, False), - (1, bool, False), - # weird floats - (float("nan"), float, True), - (float("inf"), float, True), - (float("-inf"), float, True), - (float(0), float, True), - # list/tuple - ([1], tuple[int, int], False), - ((1,2), list[int], False), -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + # empty lists + ([], List[int], True), + ([], list[dict], True), + ([], list[tuple[dict[tuple, str], str, None]], True), + # empty dicts + ({}, Dict[str, int], True), + ({}, dict[str, dict], True), + ({}, dict[str, dict[str, int]], True), + ({}, dict[str, dict[str, int]], True), + # empty sets + (set(), Set[int], True), + (set(), set[dict], True), + (set(), set[tuple[dict[tuple, str], str, None]], True), + # empty tuple + (tuple(), tuple, True), + # empty string + ("", str, True), + # empty bytes + (b"", bytes, True), + # None + (None, type(None), True), + # bools are ints, ints are not floats + (True, int, True), + (False, int, True), + (True, float, False), + (False, float, False), + (1, int, True), + (0, int, True), + (1, float, False), + (0, float, False), + (0, bool, False), + (1, bool, False), + # weird floats + (float("nan"), float, True), + (float("inf"), float, True), + (float("-inf"), float, True), + (float(0), float, True), + # list/tuple + ([1], tuple[int, int], False), + ((1, 2), list[int], False), + ], +) def test_validate_type_edge_cases(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - - -@pytest.mark.parametrize("value, expected_type, expected_result", [ - (42, list[int], False), - ([1, 2, 3], int, False), - (3.14, tuple[float], False), - (3.14, tuple[float, float], False), - (3.14, tuple[bool, str], False), - (False, tuple[bool, str], False), - (False, tuple[bool], False), - ((False,), tuple[bool], True), - (("abc",), tuple[str], True), - ("test-dict", dict[str, int], False), - ("test-dict", dict, False), -]) + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (42, list[int], False), + ([1, 2, 3], int, False), + (3.14, tuple[float], False), + (3.14, tuple[float, float], False), + (3.14, tuple[bool, str], False), + (False, tuple[bool, str], False), + (False, tuple[bool], False), + ((False,), tuple[bool], True), + (("abc",), tuple[str], True), + ("test-dict", dict[str, int], False), + ("test-dict", dict, False), + ], +) def test_validate_type_wrong_type(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e def test_validate_type_complex(): assert validate_type([1, 2, [3, 4]], List[Union[int, List[int]]]) - assert validate_type({'a': 1, 'b': {'c': 2}}, Dict[str, Union[int, Dict[str, int]]]) + assert validate_type({"a": 1, "b": {"c": 2}}, Dict[str, Union[int, Dict[str, int]]]) assert validate_type({1, (2, 3)}, Set[Union[int, Tuple[int, int]]]) - assert validate_type((1, ('a', 'b')), Tuple[int, Tuple[str, str]]) + assert validate_type((1, ("a", "b")), Tuple[int, Tuple[str, str]]) assert validate_type([{"key": "value"}], typing.List[typing.Dict[str, str]]) assert validate_type([{"key": 2}], typing.List[typing.Dict[str, str]]) == False assert validate_type([[1, 2], [3, 4]], typing.List[typing.List[int]]) assert validate_type([[1, 2], [3, "4"]], typing.List[typing.List[int]]) == False assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) - assert validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False + assert ( + validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False + ) assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) == False assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) - assert validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False + assert ( + validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) + == False + ) assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) - assert validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) == False + assert ( + validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) + == False + ) assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) - assert validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False - - -@pytest.mark.parametrize("value, expected_type, expected_result", [ - ([[[[1]]]], List[List[List[List[int]]]], True), - ([[[[1]]]], List[List[List[List[str]]]], False), - ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, int]]], True), - ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, str]]], False), - ({1, 2, 3}, Set[int], True), - ({1, 2, 3}, Set[str], False), - (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, int]], True), - (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, str]], False), -]) + assert ( + validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False + ) + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + ([[[[1]]]], List[List[List[List[int]]]], True), + ([[[[1]]]], List[List[List[List[str]]]], False), + ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, int]]], True), + ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, str]]], False), + ({1, 2, 3}, Set[int], True), + ({1, 2, 3}, Set[str], False), + (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, int]], True), + (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, str]], False), + ], +) def test_validate_type_nested(value, expected_type, expected_result): try: assert validate_type(value, expected_type) == expected_result except Exception as e: - raise Exception(f"{value = }, {expected_type = }, {expected_result = }, {e}") from e - - - \ No newline at end of file + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e From 65c39a36043f89d9a1b75d569687f1cdd48f9f40 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 15:38:45 -0700 Subject: [PATCH 054/158] fixed incorrect validate_type call, removed debug print from that function --- muutils/json_serialize/serializable_dataclass.py | 4 +++- muutils/validate_type.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 8fb10944..1a5dbe06 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -9,6 +9,8 @@ import warnings from typing import Any, Callable, Optional, Type, TypeVar, Union +from muutils.validate_type import validate_type + # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access @@ -577,7 +579,7 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: if field_type_hint is not None: # TODO: recursive type hint checking like pydantic? try: - assert _validate_type( + assert validate_type( ctor_kwargs[field.name], field_type_hint ) except Exception as e: diff --git a/muutils/validate_type.py b/muutils/validate_type.py index b4688c63..e71fc4e3 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -14,7 +14,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: args: list = typing.get_args(expected_type) # useful for debugging - print(f"{value = }, {expected_type = }, {origin = }, {args = }") + # print(f"{value = }, {expected_type = }, {origin = }, {args = }") if origin is types.UnionType or origin is typing.Union: return any(validate_type(value, arg) for arg in args) From 5f8947ded5abc695c2e7a8c713d5c8b276b24f60 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 15:59:41 -0700 Subject: [PATCH 055/158] strict on warnings mode, zanj as optional dep, fix some warnings - `make test WARN_STRICT=1` will now set `-W error` when calling pytest - added an additional step with strict warnings in CI - check for some warnings in configure_notebook (when missing format) - still getting a weird unraisable exception due to files?? idk --- makefile | 21 +++++++++++++++------ poetry.lock | 20 +++++++++++++++++++- pyproject.toml | 2 ++ 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/makefile b/makefile index d0b2a866..b2a06d5b 100644 --- a/makefile +++ b/makefile @@ -23,7 +23,6 @@ COMMIT_LOG_SINCE_LAST_VERSION := $(shell (git log $(LAST_VERSION)..HEAD --pretty TYPECHECK_COMPAT_ARGS := --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found - .PHONY: default default: help @@ -54,17 +53,26 @@ check-format: python -m isort --check-only . python -m black --check . -# coverage reports +# pytest options and coverage # -------------------------------------------------- + +PYTEST_OPTIONS ?= + # whether to run pytest with coverage report generation COV ?= 1 -ifeq ($(COV),1) - PYTEST_OPTIONS=--cov=. -else - PYTEST_OPTIONS= +ifneq ($(COV), 0) + PYTEST_OPTIONS += --cov=. endif +# whether to run pytest with warnings as errors +WARN_STRICT ?= 0 + +ifneq ($(WARN_STRICT), 0) + PYTEST_OPTIONS += -W error +endif + + .PHONY: cov cov: @echo "generate coverage reports" @@ -95,6 +103,7 @@ typing-compat: clean .PHONY: test test: clean @echo "running tests" + @echo "pytest options: $(PYTEST_OPTIONS)" $(PYPOETRY) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) diff --git a/poetry.lock b/poetry.lock index 691fc25e..4ec44ec4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2149,6 +2149,23 @@ files = [ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, ] +[[package]] +name = "zanj" +version = "0.2.2" +description = "save and load complex objects to disk without pickling" +optional = true +python-versions = "<4.0,>=3.10" +files = [ + {file = "zanj-0.2.2-py3-none-any.whl", hash = "sha256:91cce89bf8e7041e8acca3071b935899edad0ff776bf4e7bf0d98cfa6cb28f1d"}, + {file = "zanj-0.2.2.tar.gz", hash = "sha256:71c6110f9b9d1a0fe04c011156b8e2c96f6fac1a6f9dc690b6c380862fd5fb8d"}, +] + +[package.dependencies] +muutils = {version = ">=0.5.1,<0.6.0", extras = ["array"]} + +[package.extras] +pandas = ["pandas (>=1.5.3)"] + [[package]] name = "zipp" version = "3.19.2" @@ -2167,8 +2184,9 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [extras] array = ["jaxtyping", "numpy", "torch"] notebook = ["ipython"] +zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "c2a126688ee5e43af857b80f2d7062ad7858ad4b95e114784e5ab595a780607b" +content-hash = "2ce686a4013fe36f7ba4776adf40fe201c7c7855f92bb9ce1487b8372a3ae9df" diff --git a/pyproject.toml b/pyproject.toml index 55a261de..ac0d21bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,10 +23,12 @@ numpy = { version = "^1.22.4", optional = true } torch = { version = ">=1.13.1", optional = true } jaxtyping = { version = "^0.2.12", optional = true } ipython = { version = "^8.20.0", optional = true, python = "^3.10" } +zanj = { version = "^0.2.2", optional = true, python = "^3.10" } [tool.poetry.extras] array = ["numpy", "torch", "jaxtyping"] notebook = ["ipython"] +zanj = ["zanj"] [tool.poetry.group.dev.dependencies] pytest = "^8.2.2" From 03e0fcd03bae5b57261f983503dcf97bd1f0e7bc Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:01:15 -0700 Subject: [PATCH 056/158] whoops, this was meant to be in the last commit --- .github/workflows/checks.yml | 3 ++ tests/unit/nbutils/test_configure_notebook.py | 30 ++++++++++++++----- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index bef6496e..d484b13d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -64,6 +64,9 @@ jobs: - name: tests run: make test + - name: tests - strict mode + run: make test WARN_STRICT=1 + - name: check typing if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} run: make typing diff --git a/tests/unit/nbutils/test_configure_notebook.py b/tests/unit/nbutils/test_configure_notebook.py index 293820d9..82a235eb 100644 --- a/tests/unit/nbutils/test_configure_notebook.py +++ b/tests/unit/nbutils/test_configure_notebook.py @@ -44,11 +44,13 @@ def test_configure_notebook(): def test_plotshow_save(): setup_plots(plot_mode="save", fig_basepath=JUNK_DATA_PATH) - plt.plot([1, 2, 3], [1, 2, 3]) - plotshow() + with pytest.warns(UnknownFigureFormatWarning): + plt.plot([1, 2, 3], [1, 2, 3]) + plotshow() assert os.path.exists(os.path.join(JUNK_DATA_PATH, "figure-1.pdf")) - plt.plot([3, 6, 9], [2, 4, 8]) - plotshow() + with pytest.warns(UnknownFigureFormatWarning): + plt.plot([3, 6, 9], [2, 4, 8]) + plotshow() assert os.path.exists(os.path.join(JUNK_DATA_PATH, "figure-2.pdf")) @@ -68,14 +70,16 @@ def test_plotshow_save_mixed(): fig_basepath=JUNK_DATA_PATH, fig_numbered_fname="mixedfig-{num}", ) - plt.plot([1, 2, 3], [1, 2, 3]) - plotshow() + with pytest.warns(UnknownFigureFormatWarning): + plt.plot([1, 2, 3], [1, 2, 3]) + plotshow() assert os.path.exists(os.path.join(JUNK_DATA_PATH, "mixedfig-1.pdf")) plt.plot([3, 6, 9], [2, 4, 8]) plotshow(fname="mixed-test.pdf") assert os.path.exists(os.path.join(JUNK_DATA_PATH, "mixed-test.pdf")) - plt.plot([1, 1, 1], [1, 9, 9]) - plotshow() + with pytest.warns(UnknownFigureFormatWarning): + plt.plot([1, 1, 1], [1, 9, 9]) + plotshow() assert os.path.exists(os.path.join(JUNK_DATA_PATH, "mixedfig-3.pdf")) @@ -89,6 +93,16 @@ def test_warn_unknown_format(): plt.plot([1, 2, 3], [1, 2, 3]) plotshow() +def test_no_warn_unknown_format_2(): + with pytest.warns(UnknownFigureFormatWarning): + setup_plots( + plot_mode="save", + fig_basepath=JUNK_DATA_PATH, + fig_numbered_fname="mixedfig-{num}", + ) + plt.plot([1, 2, 3], [1, 2, 3]) + plotshow("no-format") + def test_no_warn_pdf_format(): with warnings.catch_warnings(): From 3722dcbc31cbcdd7f9629253112433f7864858b7 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:01:28 -0700 Subject: [PATCH 057/158] run format --- tests/unit/nbutils/test_configure_notebook.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/nbutils/test_configure_notebook.py b/tests/unit/nbutils/test_configure_notebook.py index 82a235eb..e9116a31 100644 --- a/tests/unit/nbutils/test_configure_notebook.py +++ b/tests/unit/nbutils/test_configure_notebook.py @@ -93,6 +93,7 @@ def test_warn_unknown_format(): plt.plot([1, 2, 3], [1, 2, 3]) plotshow() + def test_no_warn_unknown_format_2(): with pytest.warns(UnknownFigureFormatWarning): setup_plots( From a17a270cf2c5d27166177eb01d7ca0a547fec7cd Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:07:43 -0700 Subject: [PATCH 058/158] ignore unraisable pytest exception, fix _popen, minor make fix --- makefile | 1 - muutils/sysinfo.py | 8 +++++--- tests/unit/test_sysinfo.py | 3 +++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/makefile b/makefile index b2a06d5b..1f21721c 100644 --- a/makefile +++ b/makefile @@ -103,7 +103,6 @@ typing-compat: clean .PHONY: test test: clean @echo "running tests" - @echo "pytest options: $(PYTEST_OPTIONS)" $(PYPOETRY) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index c3b8d4e4..a37467bf 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -12,9 +12,11 @@ def _popen(cmd: list[str], split_out: bool = False) -> dict[str, typing.Any]: cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) + stdout, stderr = p.communicate() + p_out: typing.Union[str, list[str], None] - if p.stdout is not None: - p_out = p.stdout.read().decode("utf-8") + if stdout: + p_out = stdout.decode("utf-8") if split_out: assert isinstance(p_out, str) p_out = p_out.strip().split("\n") @@ -23,7 +25,7 @@ def _popen(cmd: list[str], split_out: bool = False) -> dict[str, typing.Any]: return { "stdout": p_out, - "stderr": (None if p.stderr is None else p.stderr.read().decode("utf-8")), + "stderr": stderr.decode("utf-8") if stderr else None, "returncode": p.returncode if p.returncode is None else int(p.returncode), } diff --git a/tests/unit/test_sysinfo.py b/tests/unit/test_sysinfo.py index 023d7460..2bc9ec44 100644 --- a/tests/unit/test_sysinfo.py +++ b/tests/unit/test_sysinfo.py @@ -1,6 +1,9 @@ +import pytest + from muutils.sysinfo import SysInfo +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") def test_sysinfo(): sysinfo = SysInfo.get_all() # we can't test the output because it's different on every machine From 347f069aa51021616f294aabd37146d6b81d54ec Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:12:51 -0700 Subject: [PATCH 059/158] fixes to validate_type --- muutils/validate_type.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index e71fc4e3..ed276eac 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -1,8 +1,14 @@ +from __future__ import annotations + import types import typing def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: + """Validate that a value is of the expected type. use `typeguard` for a more robust solution. + + https://github.com/agronholm/typeguard + """ if expected_type is typing.Any: return True From e41dfb3dab820986ee3fb2c6395c75d61eaa408b Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:34:21 -0700 Subject: [PATCH 060/158] better dependency stuff, going to use uv for tests in CI --- dev-requirements.txt | 81 ++++++++++++++++ makefile | 13 +++ poetry.lock | 216 ++++--------------------------------------- pyproject.toml | 2 +- 4 files changed, 112 insertions(+), 200 deletions(-) create mode 100644 dev-requirements.txt diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 00000000..1a9326ad --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,81 @@ +--extra-index-url https://download.pytorch.org/whl/cpu + +astroid==2.15.8 ; python_version >= "3.8" and python_version < "4.0" +asttokens==2.4.1 ; python_version >= "3.10" and python_version < "4.0" +black==24.4.2 ; python_version >= "3.8" and python_version < "4.0" +click==8.1.7 ; python_version >= "3.8" and python_version < "4.0" +colorama==0.4.6 ; python_version >= "3.8" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") +contourpy==1.1.1 ; python_version >= "3.8" and python_version < "4.0" +coverage-badge==1.1.1 ; python_version >= "3.8" and python_version < "4.0" +coverage==7.5.3 ; python_version >= "3.8" and python_version < "4.0" +coverage[toml]==7.5.3 ; python_version >= "3.8" and python_version < "4.0" +cycler==0.12.1 ; python_version >= "3.8" and python_version < "4.0" +decorator==5.1.1 ; python_version >= "3.10" and python_version < "4.0" +dill==0.3.8 ; python_version >= "3.8" and python_version < "4.0" +exceptiongroup==1.2.1 ; python_version >= "3.8" and python_version < "3.11" +executing==2.0.1 ; python_version >= "3.10" and python_version < "4.0" +filelock==3.15.1 ; python_version >= "3.8" and python_version < "4.0" +fonttools==4.53.0 ; python_version >= "3.8" and python_version < "4.0" +fsspec==2024.6.0 ; python_version >= "3.8" and python_version < "4.0" +importlib-metadata==7.1.0 ; python_version >= "3.8" and python_version < "3.10" +importlib-resources==6.4.0 ; python_version >= "3.8" and python_version < "3.10" +iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "4.0" +intel-openmp==2021.4.0 ; python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows" +ipython==8.25.0 ; python_version >= "3.10" and python_version < "4.0" +isort==5.13.2 ; python_version >= "3.8" and python_version < "4.0" +jaxtyping==0.2.19 ; python_version >= "3.8" and python_version < "4.0" +jedi==0.19.1 ; python_version >= "3.10" and python_version < "4.0" +jinja2==3.1.4 ; python_version >= "3.8" and python_version < "4.0" +kiwisolver==1.4.5 ; python_version >= "3.8" and python_version < "4.0" +lazy-object-proxy==1.10.0 ; python_version >= "3.8" and python_version < "4.0" +libcst==1.1.0 ; python_version >= "3.8" and python_version < "4" +markdown-it-py==3.0.0 ; python_version >= "3.8" and python_version < "4" +markupsafe==2.1.5 ; python_version >= "3.8" and python_version < "4.0" +matplotlib-inline==0.1.7 ; python_version >= "3.10" and python_version < "4.0" +matplotlib==3.7.5 ; python_version >= "3.8" and python_version < "4.0" +mccabe==0.7.0 ; python_version >= "3.8" and python_version < "4.0" +mdurl==0.1.2 ; python_version >= "3.8" and python_version < "4" +mkl==2021.4.0 ; python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows" +mpmath==1.3.0 ; python_version >= "3.8" and python_version < "4.0" +mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "4.0" +mypy==1.10.0 ; python_version >= "3.8" and python_version < "4.0" +networkx==3.1 ; python_version >= "3.8" and python_version < "4.0" +numpy==1.24.4 ; python_version >= "3.8" and python_version < "4.0" +packaging==24.1 ; python_version >= "3.8" and python_version < "4.0" +parso==0.8.4 ; python_version >= "3.10" and python_version < "4.0" +pathspec==0.12.1 ; python_version >= "3.8" and python_version < "4.0" +pexpect==4.9.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten") +pillow==10.3.0 ; python_version >= "3.8" and python_version < "4.0" +platformdirs==4.2.2 ; python_version >= "3.8" and python_version < "4.0" +plotly==5.22.0 ; python_version >= "3.8" and python_version < "4.0" +pluggy==1.5.0 ; python_version >= "3.8" and python_version < "4.0" +prompt-toolkit==3.0.47 ; python_version >= "3.10" and python_version < "4.0" +ptyprocess==0.7.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten") +pure-eval==0.2.2 ; python_version >= "3.10" and python_version < "4.0" +pycln==2.4.0 ; python_version >= "3.8" and python_version < "4" +pygments==2.18.0 ; python_version >= "3.8" and python_version < "4.0" +pylint==2.17.7 ; python_version >= "3.8" and python_version < "4.0" +pyparsing==3.1.2 ; python_version >= "3.8" and python_version < "4.0" +pytest-cov==4.1.0 ; python_version >= "3.8" and python_version < "4.0" +pytest==8.2.2 ; python_version >= "3.8" and python_version < "4.0" +python-dateutil==2.9.0.post0 ; python_version >= "3.8" and python_version < "4.0" +pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4" +rich==13.7.1 ; python_version >= "3.8" and python_version < "4" +shellingham==1.5.4 ; python_version >= "3.8" and python_version < "4" +six==1.16.0 ; python_version >= "3.8" and python_version < "4.0" +stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0" +sympy==1.12.1 ; python_version >= "3.8" and python_version < "4.0" +tbb==2021.12.0 ; python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows" +tenacity==8.4.1 ; python_version >= "3.8" and python_version < "4.0" +tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11" +tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4.0" +torch==2.3.1+cpu ; python_version >= "3.8" and python_version < "4.0" +traitlets==5.14.3 ; python_version >= "3.10" and python_version < "4.0" +typeguard==4.3.0 ; python_version >= "3.8" and python_version < "4.0" +typer==0.12.3 ; python_version >= "3.8" and python_version < "4" +typing-extensions==4.12.2 ; python_version >= "3.8" and python_version < "4.0" +typing-inspect==0.9.0 ; python_version >= "3.8" and python_version < "4" +wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "4.0" +wrapt==1.16.0 ; python_version >= "3.8" and python_version < "4.0" +zanj==0.2.2 ; python_version >= "3.10" and python_version < "4.0" +zipp==3.19.2 ; python_version >= "3.8" and python_version < "3.10" diff --git a/makefile b/makefile index 1f21721c..9f9a6e40 100644 --- a/makefile +++ b/makefile @@ -125,6 +125,19 @@ verify-git: exit 1; \ fi; \ + +EXPORT_ARGS := -E zanj -E array -E notebook --with dev --without-hashes + +.PHONY: dep-dev +dep-dev: + @echo "exporting dev and extras dependencies to dev-requirements.txt" + poetry export $(EXPORT_ARGS) --output dev-requirements.txt + +.PHONY: check-dep-dev +check-dep-dev: + @echo "checking requirements.txt matches poetry dependencies" + poetry export $(EXPORT_ARGS) | diff - dev-requirements.txt + .PHONY: build build: @echo "build via poetry, assumes checks have been run" diff --git a/poetry.lock b/poetry.lock index 4ec44ec4..61b71fc1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1229,148 +1229,6 @@ files = [ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] -[[package]] -name = "nvidia-cublas-cu12" -version = "12.1.3.1" -description = "CUBLAS native runtime libraries" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, - {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, -] - -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.1.105" -description = "CUDA profiling tools runtime libs." -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, - {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, -] - -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.1.105" -description = "NVRTC native runtime libraries" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, - {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, -] - -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.1.105" -description = "CUDA Runtime native Libraries" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, - {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, -] - -[[package]] -name = "nvidia-cudnn-cu12" -version = "8.9.2.26" -description = "cuDNN runtime libraries" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, -] - -[package.dependencies] -nvidia-cublas-cu12 = "*" - -[[package]] -name = "nvidia-cufft-cu12" -version = "11.0.2.54" -description = "CUFFT native runtime libraries" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, - {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, -] - -[[package]] -name = "nvidia-curand-cu12" -version = "10.3.2.106" -description = "CURAND native runtime libraries" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, - {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, -] - -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.4.5.107" -description = "CUDA solver native runtime libraries" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, - {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, -] - -[package.dependencies] -nvidia-cublas-cu12 = "*" -nvidia-cusparse-cu12 = "*" -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-cusparse-cu12" -version = "12.1.0.106" -description = "CUSPARSE native runtime libraries" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, - {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, -] - -[package.dependencies] -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-nccl-cu12" -version = "2.20.5" -description = "NVIDIA Collective Communication Library (NCCL) Runtime" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, - {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, -] - -[[package]] -name = "nvidia-nvjitlink-cu12" -version = "12.5.40" -description = "Nvidia JIT LTO Library" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, - {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, -] - -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.1.105" -description = "NVIDIA Tools Extension" -optional = true -python-versions = ">=3" -files = [ - {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, - {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, -] - [[package]] name = "packaging" version = "24.1" @@ -1907,31 +1765,21 @@ files = [ [[package]] name = "torch" -version = "2.3.1" +version = "2.3.1+cpu" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = true python-versions = ">=3.8.0" files = [ - {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"}, - {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"}, - {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"}, - {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"}, - {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"}, - {file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"}, - {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"}, - {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"}, - {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"}, - {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"}, - {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"}, - {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"}, - {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"}, - {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"}, - {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"}, - {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"}, - {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"}, - {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"}, - {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"}, - {file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"}, + {file = "torch-2.3.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:d679e21d871982b9234444331a26350902cfd2d5ca44ce6f49896af8b3a3087d"}, + {file = "torch-2.3.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:500bf790afc2fd374a15d06213242e517afccc50a46ea5955d321a9a68003335"}, + {file = "torch-2.3.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:a272defe305dbd944aa28a91cc3db0f0149495b3ebec2e39723a7224fa05dc57"}, + {file = "torch-2.3.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:d2965eb54d3c8818e2280a54bd53e8246a6bb34e4b10bd19c59f35b611dd9f05"}, + {file = "torch-2.3.1+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:2141a6cb7021adf2f92a0fd372cfeac524ba460bd39ce3a641d30a561e41f69a"}, + {file = "torch-2.3.1+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:6acdca2530462611095c44fd95af75ecd5b9646eac813452fe0adf31a9bc310a"}, + {file = "torch-2.3.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:cab92d5101e6db686c5525e04d87cedbcf3a556073d71d07fbe7d1ce09630ffb"}, + {file = "torch-2.3.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:dbc784569a367fd425158cf4ae82057dd3011185ba5fc68440432ba0562cb5b2"}, + {file = "torch-2.3.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:a3cb8e61ba311cee1bb7463cbdcf3ebdfd071e2091e74c5785e3687eb02819f9"}, + {file = "torch-2.3.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:df68668056e62c0332e03f43d9da5d4278b39df1ba58d30ec20d34242070955d"}, ] [package.dependencies] @@ -1940,25 +1788,18 @@ fsspec = "*" jinja2 = "*" mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" -nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.9.1)"] +[package.source] +type = "legacy" +url = "https://download.pytorch.org/whl/cpu" +reference = "torch_cpu" + [[package]] name = "traitlets" version = "5.14.3" @@ -1974,29 +1815,6 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] -[[package]] -name = "triton" -version = "2.3.1" -description = "A language and compiler for custom Deep Learning operations" -optional = true -python-versions = "*" -files = [ - {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, - {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, - {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"}, - {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"}, - {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, - {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, -] - -[package.dependencies] -filelock = "*" - -[package.extras] -build = ["cmake (>=3.20)", "lit"] -tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] -tutorials = ["matplotlib", "pandas", "tabulate", "torch"] - [[package]] name = "typeguard" version = "4.3.0" @@ -2189,4 +2007,4 @@ zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "2ce686a4013fe36f7ba4776adf40fe201c7c7855f92bb9ce1487b8372a3ae9df" +content-hash = "c02e2099cb47f3f6258d352e652c0b1d3dca91d52177f907dc4b17235942f134" diff --git a/pyproject.toml b/pyproject.toml index ac0d21bb..8d8373df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ repository = "https://github.com/mivanit/muutils" [tool.poetry.dependencies] python = "^3.8" numpy = { version = "^1.22.4", optional = true } -torch = { version = ">=1.13.1", optional = true } +torch = { version = ">=1.13.1", optional = true, source = "torch_cpu" } jaxtyping = { version = "^0.2.12", optional = true } ipython = { version = "^8.20.0", optional = true, python = "^3.10" } zanj = { version = "^0.2.2", optional = true, python = "^3.10" } From d31edc94109471ebc00dd8e21f005abec6d4ee70 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:45:43 -0700 Subject: [PATCH 061/158] minor fix to sysinfo --- muutils/sysinfo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index a37467bf..831783b5 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -146,7 +146,9 @@ def platform() -> dict: def git_info(with_log: bool = False) -> dict: git_version: dict = _popen(["git", "version"]) git_status: dict = _popen(["git", "status"]) - if git_status["stderr"].startswith("fatal: not a git repository"): + if not git_status["stderr"] or git_status["stderr"].startswith( + "fatal: not a git repository" + ): return { "git version": git_version["stdout"], "git status": git_status, From 2f4284523f9f0e8527782951dd14ea4c4e55eea0 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:46:13 -0700 Subject: [PATCH 062/158] makefile and CI changes - `RUN_GLOBAL=1` now makes `PYTHON` just `python` instead of `poetry run python` - `RUN_GLBOAL=1` now used in CI - `uv` now used for faster installs in CI - separate job for checking consistency of `poetry.lock` and `dev-requirements.txt` --- .github/workflows/checks.yml | 50 ++++++++++++++++++++++++++---------- makefile | 48 ++++++++++++++++++++++------------ 2 files changed, 68 insertions(+), 30 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index d484b13d..d4740056 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: fetch-depth: 0 @@ -22,7 +22,31 @@ jobs: run: pip install pycln isort black - name: Run Format Checks - run: make check-format + run: make check-format RUN_GLOBAL=1 + + check-deps: + name: Check poetry.lock and dev dependencies + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + + - name: Check poetry.lock + run: poetry lock --check + + - name: Check dev-requirements.txt + run: make check-dev-dep test: name: Test and Lint @@ -44,7 +68,7 @@ jobs: # torch: '2.3.1' steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: fetch-depth: 0 @@ -53,24 +77,22 @@ jobs: with: python-version: ${{ matrix.versions.python }} + - name: Install uv + run: pip install uv + - name: Install dependencies - run: | - curl -sSL https://install.python-poetry.org | python3 - - poetry lock --check - export CUDA_VISIBLE_DEVICES=0 - poetry add torch@${{ matrix.versions.torch }}+cpu --source torch_cpu - poetry install --all-extras + run: uv install -r dev-requirements.txt --system - name: tests - run: make test + run: make test RUN_GLOBAL=1 - - name: tests - strict mode - run: make test WARN_STRICT=1 + - name: tests in strict mode + run: make test WARN_STRICT=1 RUN_GLOBAL=1 - name: check typing if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} - run: make typing + run: make typing RUN_GLOBAL=1 - name: check typing in compatibility mode if: ${{ matrix.versions.python == '3.8' || matrix.versions.python == '3.9' }} - run: make typing-compat \ No newline at end of file + run: make typing-compat RUN_GLOBAL=1 \ No newline at end of file diff --git a/makefile b/makefile index 9f9a6e40..bb8cfeab 100644 --- a/makefile +++ b/makefile @@ -9,7 +9,7 @@ PYPROJECT := pyproject.toml VERSION := $(shell python -c "import re; print(re.search(r'^version\s*=\s*\"(.+?)\"', open('$(PYPROJECT)').read(), re.MULTILINE).group(1))") LAST_VERSION := $(shell cat $(LAST_VERSION_FILE)) -PYPOETRY := poetry run python +PYTHON_BASE := python # note that the commands at the end: # 1) format the git log @@ -37,21 +37,31 @@ version: exit 1; \ fi +# command line options +# -------------------------------------------------- +# for formatting or CI, we might want to run python without setting up all of poetry +RUN_GLOBAL ?= 0 +ifeq ($(RUN_GLOBAL),0) + PYTHON = poetry run $(PYTHON_BASE) +else + PYTHON = $(PYTHON_BASE) +endif + # formatting # -------------------------------------------------- .PHONY: format format: - python -m pycln --config $(PYPROJECT) --all . - python -m isort format . - python -m black . + $(PYTHON) -m pycln --config $(PYPROJECT) --all . + $(PYTHON) -m isort format . + $(PYTHON) -m black . .PHONY: check-format check-format: @echo "run format check" - python -m pycln --check --config $(PYPROJECT) . - python -m isort --check-only . - python -m black --check . + $(PYTHON) -m pycln --check --config $(PYPROJECT) . + $(PYTHON) -m isort --check-only . + $(PYTHON) -m black --check . # pytest options and coverage # -------------------------------------------------- @@ -76,9 +86,9 @@ endif .PHONY: cov cov: @echo "generate coverage reports" - $(PYPOETRY) -m coverage report -m > $(COVERAGE_REPORTS_DIR)/coverage.txt - $(PYPOETRY) -m coverage_badge -f -o $(COVERAGE_REPORTS_DIR)/coverage.svg - $(PYPOETRY) -m coverage html + $(PYTHON) -m coverage report -m > $(COVERAGE_REPORTS_DIR)/coverage.txt + $(PYTHON) -m coverage_badge -f -o $(COVERAGE_REPORTS_DIR)/coverage.svg + $(PYTHON) -m coverage html # tests # -------------------------------------------------- @@ -91,19 +101,19 @@ cov: .PHONY: typing typing: clean @echo "running type checks" - $(PYPOETRY) -m mypy --config-file $(PYPROJECT) $(PACKAGE_NAME)/ - $(PYPOETRY) -m mypy --config-file $(PYPROJECT) tests/ + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(PACKAGE_NAME)/ + $(PYTHON) -m mypy --config-file $(PYPROJECT) tests/ .PHONY: typing-compat typing-compat: clean @echo "running type checks in compatibility mode for older python versions" - $(PYPOETRY) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_COMPAT_ARGS) $(PACKAGE_NAME)/ - $(PYPOETRY) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_COMPAT_ARGS) tests/ + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_COMPAT_ARGS) $(PACKAGE_NAME)/ + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_COMPAT_ARGS) tests/ .PHONY: test test: clean @echo "running tests" - $(PYPOETRY) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) + $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) .PHONY: check @@ -188,4 +198,10 @@ clean: help: @echo -n "# list make targets" @echo ":" - @cat Makefile | sed -n '/^\.PHONY: / h; /\(^\t@*echo\|^\t:\)/ {H; x; /PHONY/ s/.PHONY: \(.*\)\n.*"\(.*\)"/ make \1\t\2/p; d; x}'| sort -k2,2 |expand -t 25 \ No newline at end of file + @cat Makefile | sed -n '/^\.PHONY: / h; /\(^\t@*echo\|^\t:\)/ {H; x; /PHONY/ s/.PHONY: \(.*\)\n.*"\(.*\)"/ make \1\t\2/p; d; x}'| sort -k2,2 |expand -t 25 + @echo "# makefile variables:" + @echo " PYTHON = $(PYTHON)" + @echo " PACKAGE_NAME = $(PACKAGE_NAME)" + @echo " VERSION = $(VERSION)" + @echo " LAST_VERSION = $(LAST_VERSION)" + @echo " PYTEST_OPTIONS = $(PYTEST_OPTIONS)" \ No newline at end of file From b9d5b7c5a8dc4e04d4aaa0091e421003d09fdf94 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:50:21 -0700 Subject: [PATCH 063/158] check-dep-dev will now also check poetry.lock --- .github/workflows/checks.yml | 5 +---- makefile | 2 ++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index d4740056..389ec395 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -42,10 +42,7 @@ jobs: run: | curl -sSL https://install.python-poetry.org | python3 - - - name: Check poetry.lock - run: poetry lock --check - - - name: Check dev-requirements.txt + - name: Check poetry.lock and dev-requirements.txt run: make check-dev-dep test: diff --git a/makefile b/makefile index bb8cfeab..cc8f9726 100644 --- a/makefile +++ b/makefile @@ -141,11 +141,13 @@ EXPORT_ARGS := -E zanj -E array -E notebook --with dev --without-hashes .PHONY: dep-dev dep-dev: @echo "exporting dev and extras dependencies to dev-requirements.txt" + poetry update poetry export $(EXPORT_ARGS) --output dev-requirements.txt .PHONY: check-dep-dev check-dep-dev: @echo "checking requirements.txt matches poetry dependencies" + poetry lock --check poetry export $(EXPORT_ARGS) | diff - dev-requirements.txt .PHONY: build From d152ad0467cec217216ecf9dd96ea9dad65093ad Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:52:10 -0700 Subject: [PATCH 064/158] fix typo in CI --- .github/workflows/checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 389ec395..44318476 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -25,7 +25,7 @@ jobs: run: make check-format RUN_GLOBAL=1 check-deps: - name: Check poetry.lock and dev dependencies + name: Check dependencies runs-on: ubuntu-latest steps: - name: Checkout code @@ -43,7 +43,7 @@ jobs: curl -sSL https://install.python-poetry.org | python3 - - name: Check poetry.lock and dev-requirements.txt - run: make check-dev-dep + run: make check-dep-dev test: name: Test and Lint From 62778d9004c78859350b95b722c7dcba31bfe85a Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:52:32 -0700 Subject: [PATCH 065/158] fix another CI typo --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 44318476..763e733b 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -78,7 +78,7 @@ jobs: run: pip install uv - name: Install dependencies - run: uv install -r dev-requirements.txt --system + run: uv pip install -r dev-requirements.txt --system - name: tests run: make test RUN_GLOBAL=1 From f6366b0f5ffab54e7ed25911f2dbcc30b3e71a7b Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 16:56:58 -0700 Subject: [PATCH 066/158] actions wip --- .github/workflows/checks.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 763e733b..973885fa 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -41,13 +41,20 @@ jobs: - name: Install Poetry run: | curl -sSL https://install.python-poetry.org | python3 - + $POETRY_HOME/bin/pip install --user poetry-plugin-export - name: Check poetry.lock and dev-requirements.txt run: make check-dep-dev + - name: Try installing all dependencies + run: | + pip install uv + uv pip install -r dev-requirements.txt --system + test: name: Test and Lint runs-on: ubuntu-latest + needs: [lint, check-deps] strategy: matrix: versions: From 9875bdefd5d28cf1533d6218f37d947a3f4af1cd Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:00:25 -0700 Subject: [PATCH 067/158] wip --- .github/workflows/checks.yml | 2 +- makefile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 973885fa..6bd079f9 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -41,7 +41,7 @@ jobs: - name: Install Poetry run: | curl -sSL https://install.python-poetry.org | python3 - - $POETRY_HOME/bin/pip install --user poetry-plugin-export + poetry self add poetry-plugin-export - name: Check poetry.lock and dev-requirements.txt run: make check-dep-dev diff --git a/makefile b/makefile index cc8f9726..4472fce7 100644 --- a/makefile +++ b/makefile @@ -147,7 +147,7 @@ dep-dev: .PHONY: check-dep-dev check-dep-dev: @echo "checking requirements.txt matches poetry dependencies" - poetry lock --check + poetry check --lock poetry export $(EXPORT_ARGS) | diff - dev-requirements.txt .PHONY: build From 2ad3d616feab71d7054307f1cb40d1c6d312f778 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:03:08 -0700 Subject: [PATCH 068/158] wip --- .github/workflows/checks.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 6bd079f9..127839ce 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -41,7 +41,8 @@ jobs: - name: Install Poetry run: | curl -sSL https://install.python-poetry.org | python3 - - poetry self add poetry-plugin-export + poetry self add poetry-plugin-export + poetry self show plugins - name: Check poetry.lock and dev-requirements.txt run: make check-dep-dev From 0c98e8200f0148f4ae9d86564d5278a16138badf Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:05:01 -0700 Subject: [PATCH 069/158] wip --- .github/workflows/checks.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 127839ce..a1301a3c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -39,8 +39,10 @@ jobs: python-version: '3.10' - name: Install Poetry + run: curl -sSL https://install.python-poetry.org | python3 - + + - name: Poetry Plugins run: | - curl -sSL https://install.python-poetry.org | python3 - poetry self add poetry-plugin-export poetry self show plugins From cd344115bd2ab2934ef3a128df3e140134a585a6 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:15:39 -0700 Subject: [PATCH 070/158] updated poetry locally --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 1a9326ad..83807f54 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -67,7 +67,7 @@ stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0" sympy==1.12.1 ; python_version >= "3.8" and python_version < "4.0" tbb==2021.12.0 ; python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows" tenacity==8.4.1 ; python_version >= "3.8" and python_version < "4.0" -tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11" +tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4.0" torch==2.3.1+cpu ; python_version >= "3.8" and python_version < "4.0" traitlets==5.14.3 ; python_version >= "3.10" and python_version < "4.0" From 611735f8a03b7ad03f8c18cc725217ca4403a2e7 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:23:25 -0700 Subject: [PATCH 071/158] wip --- .github/workflows/checks.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index a1301a3c..adede11c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -49,10 +49,11 @@ jobs: - name: Check poetry.lock and dev-requirements.txt run: make check-dep-dev - - name: Try installing all dependencies - run: | - pip install uv - uv pip install -r dev-requirements.txt --system + - name: Install uv + run: pip install uv + + - name: Install dependencies + run: uv pip install -r dev-requirements.txt --system test: name: Test and Lint From be4a8facc247247794d846191ab36d688c84ea40 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:29:36 -0700 Subject: [PATCH 072/158] wip --- .github/workflows/checks.yml | 2 +- makefile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index adede11c..5115d723 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -53,7 +53,7 @@ jobs: run: pip install uv - name: Install dependencies - run: uv pip install -r dev-requirements.txt --system + run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt test: name: Test and Lint diff --git a/makefile b/makefile index 4472fce7..ea64af22 100644 --- a/makefile +++ b/makefile @@ -136,7 +136,7 @@ verify-git: fi; \ -EXPORT_ARGS := -E zanj -E array -E notebook --with dev --without-hashes +EXPORT_ARGS := --all-extras --with dev --without-hashes .PHONY: dep-dev dep-dev: From 6c55721ed4bb86b85968814dd6a2ae786fd97eae Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:33:25 -0700 Subject: [PATCH 073/158] testing --- .github/workflows/checks.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 5115d723..5af3e4ed 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -86,7 +86,9 @@ jobs: python-version: ${{ matrix.versions.python }} - name: Install uv - run: pip install uv + run: | + pip install uv + uv pip install filelock - name: Install dependencies run: uv pip install -r dev-requirements.txt --system From 5aada446865b223ed7beff368fcf2e2efbfd5acc Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:35:36 -0700 Subject: [PATCH 074/158] more testing --- .github/workflows/checks.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 5af3e4ed..318354a4 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -86,9 +86,10 @@ jobs: python-version: ${{ matrix.versions.python }} - name: Install uv - run: | - pip install uv - uv pip install filelock + run: pip install uv + + - name: Install filelock + run: uv pip install filelock - name: Install dependencies run: uv pip install -r dev-requirements.txt --system From 2c7c0773975395def530571a5785410dc0e20084 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:37:08 -0700 Subject: [PATCH 075/158] testinggg --- .github/workflows/checks.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 318354a4..bc2de570 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -51,6 +51,9 @@ jobs: - name: Install uv run: pip install uv + + - name: Install filelock + run: uv pip install filelock - name: Install dependencies run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt @@ -88,9 +91,6 @@ jobs: - name: Install uv run: pip install uv - - name: Install filelock - run: uv pip install filelock - - name: Install dependencies run: uv pip install -r dev-requirements.txt --system From 8a47ebf6e17a4ae8619954570be0810394aea766 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:39:12 -0700 Subject: [PATCH 076/158] more testing, filelock still broken --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index bc2de570..bfc0c541 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -53,7 +53,7 @@ jobs: run: pip install uv - name: Install filelock - run: uv pip install filelock + run: uv pip install filelock --system - name: Install dependencies run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt From 87b15565157f645a702d038eeba6f109f513ac96 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 17:59:45 -0700 Subject: [PATCH 077/158] no zanj dep --- dev-requirements.txt | 1 - makefile | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 83807f54..6f2aaf7c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -77,5 +77,4 @@ typing-extensions==4.12.2 ; python_version >= "3.8" and python_version < "4.0" typing-inspect==0.9.0 ; python_version >= "3.8" and python_version < "4" wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "4.0" wrapt==1.16.0 ; python_version >= "3.8" and python_version < "4.0" -zanj==0.2.2 ; python_version >= "3.10" and python_version < "4.0" zipp==3.19.2 ; python_version >= "3.8" and python_version < "3.10" diff --git a/makefile b/makefile index ea64af22..78b50b92 100644 --- a/makefile +++ b/makefile @@ -136,7 +136,8 @@ verify-git: fi; \ -EXPORT_ARGS := --all-extras --with dev --without-hashes +# no zanj, it gets special treatment because it depends on muutils +EXPORT_ARGS := -E array -E notebook --with dev --without-hashes .PHONY: dep-dev dep-dev: From be25a18050fe77e7bf2532c8f5b63fed213f5632 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 18:01:28 -0700 Subject: [PATCH 078/158] wip --- .github/workflows/checks.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index bfc0c541..13dd7081 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -52,11 +52,11 @@ jobs: - name: Install uv run: pip install uv - - name: Install filelock - run: uv pip install filelock --system - - name: Install dependencies run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt + + - name: Install zanj (requires muutils) + run: uv pip install zanj --system test: name: Test and Lint @@ -92,7 +92,10 @@ jobs: run: pip install uv - name: Install dependencies - run: uv pip install -r dev-requirements.txt --system + run: | + uv pip install -r dev-requirements.txt --system + uv pip install torch==${{ matrix.versions.torch}} --system + uv pip install zanj --system - name: tests run: make test RUN_GLOBAL=1 From 77ca48fe859abb62a3c5eeb8df62032a9e6b2646 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 18:03:10 -0700 Subject: [PATCH 079/158] wip --- .github/workflows/checks.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 13dd7081..eb12bfff 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -51,6 +51,9 @@ jobs: - name: Install uv run: pip install uv + + - name: Install problematic dependencies + run: uv pip install filelock ffspec --system - name: Install dependencies run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt @@ -94,7 +97,7 @@ jobs: - name: Install dependencies run: | uv pip install -r dev-requirements.txt --system - uv pip install torch==${{ matrix.versions.torch}} --system + uv pip install torch==${{ matrix.versions.torch}}+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu uv pip install zanj --system - name: tests From 6b294aa14f3359e7b5be28bed08b8f6138f0c110 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 18:07:42 -0700 Subject: [PATCH 080/158] wip --- .github/workflows/checks.yml | 2 +- poetry.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index eb12bfff..24a8c229 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -53,7 +53,7 @@ jobs: run: pip install uv - name: Install problematic dependencies - run: uv pip install filelock ffspec --system + run: uv pip install filelock fsspec --system - name: Install dependencies run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt diff --git a/poetry.lock b/poetry.lock index 61b71fc1..4ba88a98 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "astroid" From 61b6d82727f3ad77b84ac9b58de29e4953bdd778 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 18:30:50 -0700 Subject: [PATCH 081/158] test --- .github/workflows/checks.yml | 12 ++++++------ dev-requirements.txt | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 24a8c229..fd239776 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -41,13 +41,13 @@ jobs: - name: Install Poetry run: curl -sSL https://install.python-poetry.org | python3 - - - name: Poetry Plugins - run: | - poetry self add poetry-plugin-export - poetry self show plugins + # - name: Poetry Plugins + # run: | + # poetry self add poetry-plugin-export + # poetry self show plugins - - name: Check poetry.lock and dev-requirements.txt - run: make check-dep-dev + # - name: Check poetry.lock and dev-requirements.txt + # run: make check-dep-dev - name: Install uv run: pip install uv diff --git a/dev-requirements.txt b/dev-requirements.txt index 6f2aaf7c..1a002d27 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ ---extra-index-url https://download.pytorch.org/whl/cpu +# --extra-index-url https://download.pytorch.org/whl/cpu astroid==2.15.8 ; python_version >= "3.8" and python_version < "4.0" asttokens==2.4.1 ; python_version >= "3.10" and python_version < "4.0" From 3c9c15dbf9b63b6091f1afbd05ae09fd698dea5f Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 18:32:03 -0700 Subject: [PATCH 082/158] test --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 1a002d27..506ae3bd 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -69,7 +69,7 @@ tbb==2021.12.0 ; python_version >= "3.8" and python_version < "4.0" and platform tenacity==8.4.1 ; python_version >= "3.8" and python_version < "4.0" tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4.0" -torch==2.3.1+cpu ; python_version >= "3.8" and python_version < "4.0" +torch==2.3.1 ; python_version >= "3.8" and python_version < "4.0" traitlets==5.14.3 ; python_version >= "3.10" and python_version < "4.0" typeguard==4.3.0 ; python_version >= "3.8" and python_version < "4.0" typer==0.12.3 ; python_version >= "3.8" and python_version < "4" From d73f7e74e8aa94dd13bed43fb4c56dd50b12b38c Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 18:35:55 -0700 Subject: [PATCH 083/158] test --- dev-requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 506ae3bd..859d2aa8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,5 @@ -# --extra-index-url https://download.pytorch.org/whl/cpu +--extra-index-url https://download.pytorch.org/whl/cpu +--extra-index-url https://pypi.org/ astroid==2.15.8 ; python_version >= "3.8" and python_version < "4.0" asttokens==2.4.1 ; python_version >= "3.10" and python_version < "4.0" From 0b14bb318a5ecadafe9945731b30195f38c295b1 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 18 Jun 2024 18:38:49 -0700 Subject: [PATCH 084/158] pls work --- .github/workflows/checks.yml | 15 +-- dev-requirements.txt | 16 ++- makefile | 3 +- poetry.lock | 216 ++++++++++++++++++++++++++++++++--- pyproject.toml | 10 +- 5 files changed, 225 insertions(+), 35 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index fd239776..80b2c392 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -41,19 +41,16 @@ jobs: - name: Install Poetry run: curl -sSL https://install.python-poetry.org | python3 - - # - name: Poetry Plugins - # run: | - # poetry self add poetry-plugin-export - # poetry self show plugins + - name: Poetry Plugins + run: | + poetry self add poetry-plugin-export + poetry self show plugins - # - name: Check poetry.lock and dev-requirements.txt - # run: make check-dep-dev + - name: Check poetry.lock and dev-requirements.txt + run: make check-dep-dev - name: Install uv run: pip install uv - - - name: Install problematic dependencies - run: uv pip install filelock fsspec --system - name: Install dependencies run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt diff --git a/dev-requirements.txt b/dev-requirements.txt index 859d2aa8..6bd7f94c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,6 +1,3 @@ ---extra-index-url https://download.pytorch.org/whl/cpu ---extra-index-url https://pypi.org/ - astroid==2.15.8 ; python_version >= "3.8" and python_version < "4.0" asttokens==2.4.1 ; python_version >= "3.10" and python_version < "4.0" black==24.4.2 ; python_version >= "3.8" and python_version < "4.0" @@ -42,6 +39,18 @@ mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "4.0" mypy==1.10.0 ; python_version >= "3.8" and python_version < "4.0" networkx==3.1 ; python_version >= "3.8" and python_version < "4.0" numpy==1.24.4 ; python_version >= "3.8" and python_version < "4.0" +nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-nccl-cu12==2.20.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-nvjitlink-cu12==12.5.40 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" +nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" packaging==24.1 ; python_version >= "3.8" and python_version < "4.0" parso==0.8.4 ; python_version >= "3.10" and python_version < "4.0" pathspec==0.12.1 ; python_version >= "3.8" and python_version < "4.0" @@ -72,6 +81,7 @@ tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4.0" torch==2.3.1 ; python_version >= "3.8" and python_version < "4.0" traitlets==5.14.3 ; python_version >= "3.10" and python_version < "4.0" +triton==2.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12" and python_version >= "3.8" typeguard==4.3.0 ; python_version >= "3.8" and python_version < "4.0" typer==0.12.3 ; python_version >= "3.8" and python_version < "4" typing-extensions==4.12.2 ; python_version >= "3.8" and python_version < "4.0" diff --git a/makefile b/makefile index 78b50b92..b8f9dda2 100644 --- a/makefile +++ b/makefile @@ -137,7 +137,8 @@ verify-git: # no zanj, it gets special treatment because it depends on muutils -EXPORT_ARGS := -E array -E notebook --with dev --without-hashes +# without urls since pytorch extra index breaks things? +EXPORT_ARGS := -E array -E notebook --with dev --without-hashes --without-urls .PHONY: dep-dev dep-dev: diff --git a/poetry.lock b/poetry.lock index 4ba88a98..98a6a45b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1229,6 +1229,148 @@ files = [ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +description = "CUBLAS native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +description = "NVRTC native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +description = "cuDNN runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.20.5" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.5.40" +description = "Nvidia JIT LTO Library" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = true +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] + [[package]] name = "packaging" version = "24.1" @@ -1765,21 +1907,31 @@ files = [ [[package]] name = "torch" -version = "2.3.1+cpu" +version = "2.3.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = true python-versions = ">=3.8.0" files = [ - {file = "torch-2.3.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:d679e21d871982b9234444331a26350902cfd2d5ca44ce6f49896af8b3a3087d"}, - {file = "torch-2.3.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:500bf790afc2fd374a15d06213242e517afccc50a46ea5955d321a9a68003335"}, - {file = "torch-2.3.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:a272defe305dbd944aa28a91cc3db0f0149495b3ebec2e39723a7224fa05dc57"}, - {file = "torch-2.3.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:d2965eb54d3c8818e2280a54bd53e8246a6bb34e4b10bd19c59f35b611dd9f05"}, - {file = "torch-2.3.1+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:2141a6cb7021adf2f92a0fd372cfeac524ba460bd39ce3a641d30a561e41f69a"}, - {file = "torch-2.3.1+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:6acdca2530462611095c44fd95af75ecd5b9646eac813452fe0adf31a9bc310a"}, - {file = "torch-2.3.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:cab92d5101e6db686c5525e04d87cedbcf3a556073d71d07fbe7d1ce09630ffb"}, - {file = "torch-2.3.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:dbc784569a367fd425158cf4ae82057dd3011185ba5fc68440432ba0562cb5b2"}, - {file = "torch-2.3.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:a3cb8e61ba311cee1bb7463cbdcf3ebdfd071e2091e74c5785e3687eb02819f9"}, - {file = "torch-2.3.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:df68668056e62c0332e03f43d9da5d4278b39df1ba58d30ec20d34242070955d"}, + {file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"}, + {file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"}, + {file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"}, + {file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"}, + {file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"}, + {file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"}, + {file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"}, + {file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"}, + {file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"}, + {file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"}, + {file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"}, + {file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"}, + {file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"}, + {file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"}, + {file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"}, + {file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"}, + {file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"}, + {file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"}, + {file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"}, + {file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"}, ] [package.dependencies] @@ -1788,18 +1940,25 @@ fsspec = "*" jinja2 = "*" mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" +triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.9.1)"] -[package.source] -type = "legacy" -url = "https://download.pytorch.org/whl/cpu" -reference = "torch_cpu" - [[package]] name = "traitlets" version = "5.14.3" @@ -1815,6 +1974,29 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "triton" +version = "2.3.1" +description = "A language and compiler for custom Deep Learning operations" +optional = true +python-versions = "*" +files = [ + {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, + {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, + {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"}, + {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"}, + {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, + {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] + [[package]] name = "typeguard" version = "4.3.0" @@ -2007,4 +2189,4 @@ zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "c02e2099cb47f3f6258d352e652c0b1d3dca91d52177f907dc4b17235942f134" +content-hash = "a58332e583188399d06d9346bc20aa481323d41454b57917c781d8fc72001e76" diff --git a/pyproject.toml b/pyproject.toml index 8d8373df..5fc2fa3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ repository = "https://github.com/mivanit/muutils" [tool.poetry.dependencies] python = "^3.8" numpy = { version = "^1.22.4", optional = true } -torch = { version = ">=1.13.1", optional = true, source = "torch_cpu" } +torch = { version = ">=1.13.1", optional = true } jaxtyping = { version = "^0.2.12", optional = true } ipython = { version = "^8.20.0", optional = true, python = "^3.10" } zanj = { version = "^0.2.2", optional = true, python = "^3.10" } @@ -46,10 +46,10 @@ plotly = "^5.0.0" requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" -[[tool.poetry.source]] -name = "torch_cpu" -url = "https://download.pytorch.org/whl/cpu" -priority = "explicit" +# [[tool.poetry.source]] +# name = "torch_cpu" +# url = "https://download.pytorch.org/whl/cpu" +# priority = "explicit" # TODO: make all of the following ignored across all formatting/linting # tests/input_data, tests/junk_data, muutils/_wip From fe264151c5e6bd48538f0c5e73e1ccf0433946c1 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 00:42:01 -0700 Subject: [PATCH 085/158] wip --- .github/workflows/checks.yml | 3 ++- dev-requirements.txt | 24 ------------------------ makefile | 5 +++-- poetry.lock | 3 ++- pyproject.toml | 1 + 5 files changed, 8 insertions(+), 28 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 80b2c392..8a784d47 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -92,9 +92,10 @@ jobs: run: pip install uv - name: Install dependencies + # install torch first to avoid pytorch index messing things up run: | - uv pip install -r dev-requirements.txt --system uv pip install torch==${{ matrix.versions.torch}}+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu + uv pip install -r dev-requirements.txt --system --no-deps uv pip install zanj --system - name: tests diff --git a/dev-requirements.txt b/dev-requirements.txt index 6bd7f94c..6de99214 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -12,45 +12,25 @@ decorator==5.1.1 ; python_version >= "3.10" and python_version < "4.0" dill==0.3.8 ; python_version >= "3.8" and python_version < "4.0" exceptiongroup==1.2.1 ; python_version >= "3.8" and python_version < "3.11" executing==2.0.1 ; python_version >= "3.10" and python_version < "4.0" -filelock==3.15.1 ; python_version >= "3.8" and python_version < "4.0" fonttools==4.53.0 ; python_version >= "3.8" and python_version < "4.0" -fsspec==2024.6.0 ; python_version >= "3.8" and python_version < "4.0" importlib-metadata==7.1.0 ; python_version >= "3.8" and python_version < "3.10" importlib-resources==6.4.0 ; python_version >= "3.8" and python_version < "3.10" iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "4.0" -intel-openmp==2021.4.0 ; python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows" ipython==8.25.0 ; python_version >= "3.10" and python_version < "4.0" isort==5.13.2 ; python_version >= "3.8" and python_version < "4.0" jaxtyping==0.2.19 ; python_version >= "3.8" and python_version < "4.0" jedi==0.19.1 ; python_version >= "3.10" and python_version < "4.0" -jinja2==3.1.4 ; python_version >= "3.8" and python_version < "4.0" kiwisolver==1.4.5 ; python_version >= "3.8" and python_version < "4.0" lazy-object-proxy==1.10.0 ; python_version >= "3.8" and python_version < "4.0" libcst==1.1.0 ; python_version >= "3.8" and python_version < "4" markdown-it-py==3.0.0 ; python_version >= "3.8" and python_version < "4" -markupsafe==2.1.5 ; python_version >= "3.8" and python_version < "4.0" matplotlib-inline==0.1.7 ; python_version >= "3.10" and python_version < "4.0" matplotlib==3.7.5 ; python_version >= "3.8" and python_version < "4.0" mccabe==0.7.0 ; python_version >= "3.8" and python_version < "4.0" mdurl==0.1.2 ; python_version >= "3.8" and python_version < "4" -mkl==2021.4.0 ; python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows" -mpmath==1.3.0 ; python_version >= "3.8" and python_version < "4.0" mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "4.0" mypy==1.10.0 ; python_version >= "3.8" and python_version < "4.0" -networkx==3.1 ; python_version >= "3.8" and python_version < "4.0" numpy==1.24.4 ; python_version >= "3.8" and python_version < "4.0" -nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-nccl-cu12==2.20.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-nvjitlink-cu12==12.5.40 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" -nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "4.0" packaging==24.1 ; python_version >= "3.8" and python_version < "4.0" parso==0.8.4 ; python_version >= "3.10" and python_version < "4.0" pathspec==0.12.1 ; python_version >= "3.8" and python_version < "4.0" @@ -74,14 +54,10 @@ rich==13.7.1 ; python_version >= "3.8" and python_version < "4" shellingham==1.5.4 ; python_version >= "3.8" and python_version < "4" six==1.16.0 ; python_version >= "3.8" and python_version < "4.0" stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0" -sympy==1.12.1 ; python_version >= "3.8" and python_version < "4.0" -tbb==2021.12.0 ; python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows" tenacity==8.4.1 ; python_version >= "3.8" and python_version < "4.0" tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4.0" -torch==2.3.1 ; python_version >= "3.8" and python_version < "4.0" traitlets==5.14.3 ; python_version >= "3.10" and python_version < "4.0" -triton==2.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12" and python_version >= "3.8" typeguard==4.3.0 ; python_version >= "3.8" and python_version < "4.0" typer==0.12.3 ; python_version >= "3.8" and python_version < "4" typing-extensions==4.12.2 ; python_version >= "3.8" and python_version < "4.0" diff --git a/makefile b/makefile index b8f9dda2..2d790652 100644 --- a/makefile +++ b/makefile @@ -137,8 +137,9 @@ verify-git: # no zanj, it gets special treatment because it depends on muutils -# without urls since pytorch extra index breaks things? -EXPORT_ARGS := -E array -E notebook --with dev --without-hashes --without-urls +# without urls since pytorch extra index breaks things +# no torch because we install it manually in CI +EXPORT_ARGS := -E array_no_torch -E notebook --with dev --without-hashes --without-urls .PHONY: dep-dev dep-dev: diff --git a/poetry.lock b/poetry.lock index 98a6a45b..c5d8915a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2183,10 +2183,11 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [extras] array = ["jaxtyping", "numpy", "torch"] +array-no-torch = ["jaxtyping", "numpy"] notebook = ["ipython"] zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "a58332e583188399d06d9346bc20aa481323d41454b57917c781d8fc72001e76" +content-hash = "6992867457041d99bcd9915054e5e1d3c666f8a6eee2aba70835323fa85330ce" diff --git a/pyproject.toml b/pyproject.toml index 5fc2fa3c..ac7a31bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ zanj = { version = "^0.2.2", optional = true, python = "^3.10" } [tool.poetry.extras] array = ["numpy", "torch", "jaxtyping"] +array_no_torch = ["numpy", "jaxtyping"] notebook = ["ipython"] zanj = ["zanj"] From 68fcfd8d67e6e94d2e541e5e07b1265b0b7a6738 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 00:42:46 -0700 Subject: [PATCH 086/158] minor --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 8a784d47..6b1c468f 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -94,8 +94,8 @@ jobs: - name: Install dependencies # install torch first to avoid pytorch index messing things up run: | - uv pip install torch==${{ matrix.versions.torch}}+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu uv pip install -r dev-requirements.txt --system --no-deps + uv pip install torch==${{ matrix.versions.torch}}+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu uv pip install zanj --system - name: tests From be9af6b561b6c2d89f64b3ae14d2a0a82f252bd1 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 00:46:37 -0700 Subject: [PATCH 087/158] wip --- .github/workflows/checks.yml | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 6b1c468f..cdcd2bbd 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -54,7 +54,10 @@ jobs: - name: Install dependencies run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt - + + - name: Install torch (special) + run: uv pip install torch==1.13.1+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu + - name: Install zanj (requires muutils) run: uv pip install zanj --system @@ -96,7 +99,14 @@ jobs: run: | uv pip install -r dev-requirements.txt --system --no-deps uv pip install torch==${{ matrix.versions.torch}}+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu - uv pip install zanj --system + + - name: Install muutils + run: uv pip install . --system + + - name: Install zanj + # not yet available for python 3.8 and 3.9 + if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} + run: uv pip install zanj --system - name: tests run: make test RUN_GLOBAL=1 From 4df7722f3a484e5bca2356d3472dff8338aab9b9 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 00:51:08 -0700 Subject: [PATCH 088/158] wip --- .github/workflows/checks.yml | 7 +++++-- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index cdcd2bbd..813e35f6 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -57,7 +57,10 @@ jobs: - name: Install torch (special) run: uv pip install torch==1.13.1+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu - + + - name: Install muutils (local) + run: uv pip install . --system + - name: Install zanj (requires muutils) run: uv pip install zanj --system @@ -103,7 +106,7 @@ jobs: - name: Install muutils run: uv pip install . --system - - name: Install zanj + - name: Install zanj (>=3.10 only) # not yet available for python 3.8 and 3.9 if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} run: uv pip install zanj --system diff --git a/pyproject.toml b/pyproject.toml index ac7a31bb..930ef584 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "muutils" -version = "0.5.12" +version = "0.6.0" description = "A collection of miscellaneous python utilities" license = "GPL-3.0-only" authors = ["mivanit "] From 55143d809461665369493c0d6b36c977a1317205 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:11:23 -0700 Subject: [PATCH 089/158] gitignore --- .gitignore | 1 + test.ipynb | 448 ----------------------------------------------------- 2 files changed, 1 insertion(+), 448 deletions(-) delete mode 100644 test.ipynb diff --git a/.gitignore b/.gitignore index 0f47f10a..04d04471 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +_test.ipynb JUNK_DATA_PATH/ junk_data .pypi-token diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 0041f15a..00000000 --- a/test.ipynb +++ /dev/null @@ -1,448 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import typing" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "x = list[str]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "list[str]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "types.GenericAlias" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "type(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "isinstance(x, typing.GenericAlias)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "y = str|int" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "isinstance(y, type)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "type" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "type(y.__args__[0])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Iterable, Sequence, Type" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "isinstance(list, type(Iterable))" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "list[str, int]" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list[str, int]" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "typing.Any" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "typing.Any" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "origin = args = (, )\n", - "origin = args = (, )\n", - "origin = args = (,)\n", - "origin = args = (,)\n", - "origin = args = (,)\n", - "origin = args = (,)\n" - ] - } - ], - "source": [ - "from typing import Any, Union\n", - "from types import UnionType\n", - "import types\n", - "\n", - "def _validate_type(value: Any, expected_type: Any) -> bool:\n", - " if expected_type is Any:\n", - " return True\n", - " \n", - " # base type without args\n", - " if isinstance(expected_type, type):\n", - " return isinstance(value, expected_type)\n", - "\n", - " origin: type = typing.get_origin(expected_type)\n", - " args: list = typing.get_args(expected_type)\n", - " \n", - " print(f\"{origin = } {args = }\")\n", - "\n", - " if origin is types.UnionType:\n", - " return any(_validate_type(value, arg) for arg in args)\n", - "\n", - " # generic alias, more complicated\n", - " if isinstance(expected_type, (typing.GenericAlias, typing._GenericAlias)):\n", - "\n", - " if origin is list:\n", - " assert len(args) == 1\n", - " return isinstance(value, list) and all(_validate_type(item, args[0]) for item in value)\n", - " \n", - " if origin is dict:\n", - " assert len(args) == 2\n", - " return isinstance(value, dict) and all(\n", - " _validate_type(key, args[0]) and _validate_type(val, args[1])\n", - " for key, val in value.items()\n", - " )\n", - " \n", - " if origin is set:\n", - " assert len(args) == 1\n", - " return isinstance(value, set) and all(_validate_type(item, args[0]) for item in value)\n", - " \n", - " if origin is tuple:\n", - " if len(value) != len(args):\n", - " return False\n", - " return all(_validate_type(item, arg) for item, arg in zip(value, args))\n", - " \n", - " raise ValueError(f\"Unsupported generic alias {expected_type}\")\n", - "\n", - " else:\n", - " raise ValueError(f\"Unsupported type hint {expected_type = } for {value = }\")\n", - " \n", - "assert _validate_type(1, str|int)\n", - "assert _validate_type(\"a\", str|int)\n", - "assert _validate_type([1, 2, 3], list[int])\n", - "assert not _validate_type([1, 2, 3], list[str])\n", - "assert _validate_type({\"a\", \"b\", \"c\"}, set[str])\n", - "assert not _validate_type({\"a\", \"b\", 1}, set[int])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "str | int" - ] - }, - "execution_count": 61, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [], - "source": [ - "z = list[int]" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [], - "source": [ - "z1 = typing.Union[int, str]" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "typing._UnionGenericAlias" - ] - }, - "execution_count": 82, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "type(z1)" - ] - }, - { - "cell_type": "code", - "execution_count": 80, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "typing.Union[int, str]" - ] - }, - "execution_count": 80, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "z1" - ] - }, - { - "cell_type": "code", - "execution_count": 83, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "typing.Union" - ] - }, - "execution_count": 83, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "typing.get_origin(z1)" - ] - }, - { - "cell_type": "code", - "execution_count": 72, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 72, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "isinstance(z1, typing._GenericAlias)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "y_o = typing.get_origin(y)\n", - "y_a = typing.get_args(y)" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import types\n", - "types.UnionType\n", - "\n", - "y_o is types.UnionType" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 0a616f977fc200d9d63c1c86c79c8e5399d45bbc Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:12:42 -0700 Subject: [PATCH 090/158] fix torch version in dep check --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 813e35f6..1ca2ca0c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -56,7 +56,7 @@ jobs: run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt - name: Install torch (special) - run: uv pip install torch==1.13.1+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu + run: uv pip install torch==2.3.1+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu - name: Install muutils (local) run: uv pip install . --system From 8f7816be7d1fd11c7d8fa9eea79576bdd5edbcb6 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:13:58 -0700 Subject: [PATCH 091/158] version bump to 0.5.13 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 930ef584..b67477ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "muutils" -version = "0.6.0" +version = "0.5.13" description = "A collection of miscellaneous python utilities" license = "GPL-3.0-only" authors = ["mivanit "] From 41b838b39813e8dda4f8cdf4812ecc33803f12af Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:15:38 -0700 Subject: [PATCH 092/158] excludes in format in pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b67477ba..8c1ee29f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,12 +64,12 @@ filterwarnings = [ [tool.pycln] all = true -exclude = "tests/input_data" +exclude = ["tests/input_data", "tests/junk_data"] [tool.isort] profile = "black" ignore_comments = false -extend_skip = "tests/input_data" +extend_skip = ["tests/input_data", "tests/junk_data"] [tool.black] extend-exclude = "tests/input_data" From b499c74f61ea6ba2537f0ef121ea02a0c2fe1147 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:15:59 -0700 Subject: [PATCH 093/158] fix for python 3.10 --- muutils/validate_type.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index ed276eac..a9c31c0f 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -14,7 +14,11 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # base type without args if isinstance(expected_type, type): - return isinstance(value, expected_type) + try: + # if you use args on a type like `dict[str, int]`, this will fail + return isinstance(value, expected_type) + except TypeError: + pass origin: type = typing.get_origin(expected_type) args: list = typing.get_args(expected_type) From c65b04ce6970ba1bb3db71b346e13c6484b51be0 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:20:18 -0700 Subject: [PATCH 094/158] add future import annotations --- muutils/jsonlines.py | 2 ++ muutils/kappa.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/muutils/jsonlines.py b/muutils/jsonlines.py index 4cae67c9..2d9fca68 100644 --- a/muutils/jsonlines.py +++ b/muutils/jsonlines.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import gzip import json from typing import Callable, Sequence diff --git a/muutils/kappa.py b/muutils/kappa.py index b53b420e..888d9c2b 100644 --- a/muutils/kappa.py +++ b/muutils/kappa.py @@ -5,6 +5,8 @@ a `lambda` is an anonymous function: kappa is the letter before lambda in the greek alphabet, hence the name of this class""" +from __future__ import annotations + from typing import Callable, Mapping, TypeVar _kappa_K = TypeVar("_kappa_K") From 69055017675d6d7ff338b97138615c3f57b6950a Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:31:21 -0700 Subject: [PATCH 095/158] about to do something cursed --- tests/unit/{ => test_validate}/test_validate_type.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/unit/{ => test_validate}/test_validate_type.py (100%) diff --git a/tests/unit/test_validate_type.py b/tests/unit/test_validate/test_validate_type.py similarity index 100% rename from tests/unit/test_validate_type.py rename to tests/unit/test_validate/test_validate_type.py From c281ac9b6272ddd574be8364c7e48f905d6aac12 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:36:41 -0700 Subject: [PATCH 096/158] switched all type hints to typing.Something --- .../unit/test_validate/test_validate_type.py | 228 +++++++++--------- 1 file changed, 114 insertions(+), 114 deletions(-) diff --git a/tests/unit/test_validate/test_validate_type.py b/tests/unit/test_validate/test_validate_type.py index 6ed340ec..c535c15c 100644 --- a/tests/unit/test_validate/test_validate_type.py +++ b/tests/unit/test_validate/test_validate_type.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import pytest @@ -20,14 +20,14 @@ (True, bool, True), (None, type(None), True), (None, int, False), - ([1, 2, 3], list, True), - ([1, 2, 3], List, True), - ({"a": 1, "b": 2}, dict, True), - ({"a": 1, "b": 2}, Dict, True), - ({1, 2, 3}, set, True), - ({1, 2, 3}, Set, True), - ((1, 2, 3), tuple, True), - ((1, 2, 3), Tuple, True), + ([1, 2, 3], typing.List, True), + ([1, 2, 3], typing.List, True), + ({"a": 1, "b": 2}, typing.Dict, True), + ({"a": 1, "b": 2}, typing.Dict, True), + ({1, 2, 3}, typing.Set, True), + ({1, 2, 3}, typing.Set, True), + ((1, 2, 3), typing.Tuple, True), + ((1, 2, 3), typing.Tuple, True), (b"bytes", bytes, True), (b"bytes", str, False), ("3.14", float, False), @@ -116,22 +116,22 @@ def test_validate_type_union(value, expected_type, expected_result): (42, Optional[int], True), ("hello", Optional[int], False), (3.14, Optional[int], False), - ([1], Optional[List[int]], True), + ([1], Optional[typing.List[int]], True), (None, Optional[int], True), (None, Optional[str], True), (None, Optional[int], True), (None, Optional[None], True), - (None, Optional[list[dict[str, int]]], True), + (None, Optional[typing.List[typing.Dict[str, int]]], True), (42, int | None, True), ("hello", int | None, False), (3.14, int | None, False), - ([1], List[int] | None, True), + ([1], typing.List[int] | None, True), (None, int | None, True), (None, str | None, True), (None, None | str, True), (None, None | int, True), (None, str | int, False), - (None, None | List[Dict[str, int]], True), + (None, None | typing.List[typing.Dict[str, int]], True), ], ) def test_validate_type_optional(value, expected_type, expected_result): @@ -146,14 +146,14 @@ def test_validate_type_optional(value, expected_type, expected_result): @pytest.mark.parametrize( "value, expected_type, expected_result", [ - (42, List[int], False), - ([1, 2, 3], List[int], True), - ([1, 2, 3], List[str], False), - (["a", "b", "c"], List[str], True), - ([1, "a", 3], List[int], False), - (42, List[int], False), - ([1, 2, 3], List[int], True), - ([1, "2", 3], List[int], False), + (42, typing.List[int], False), + ([1, 2, 3], typing.List[int], True), + ([1, 2, 3], typing.List[str], False), + (["a", "b", "c"], typing.List[str], True), + ([1, "a", 3], typing.List[int], False), + (42, typing.List[int], False), + ([1, 2, 3], typing.List[int], True), + ([1, "2", 3], typing.List[int], False), ], ) def test_validate_type_list(value, expected_type, expected_result): @@ -168,22 +168,22 @@ def test_validate_type_list(value, expected_type, expected_result): @pytest.mark.parametrize( "value, expected_type, expected_result", [ - (42, dict[str, int], False), - ({"a": 1, "b": 2}, dict[str, int], True), - ({"a": 1, "b": 2}, dict[int, str], False), - (42, Dict[str, int], False), - ({"a": 1, "b": 2}, Dict[str, int], True), - ({"a": 1, "b": 2}, Dict[int, str], False), - ({1: "a", 2: "b"}, Dict[int, str], True), - ({1: "a", 2: "b"}, Dict[str, int], False), - ({"a": 1, "b": "c"}, Dict[str, int], False), - ([("a", 1), ("b", 2)], Dict[str, int], False), - ({"key": "value"}, Dict[str, str], True), - ({"key": 2}, Dict[str, str], False), - ({"key": 2}, Dict[str, int], True), - ({"key": 2.0}, Dict[str, int], False), - ({"a": 1, "b": 2}, Dict[str, int], True), - ({"a": 1, "b": "2"}, Dict[str, int], False), + (42, typing.Dict[str, int], False), + ({"a": 1, "b": 2}, typing.Dict[str, int], True), + ({"a": 1, "b": 2}, typing.Dict[int, str], False), + (42, typing.Dict[str, int], False), + ({"a": 1, "b": 2}, typing.Dict[str, int], True), + ({"a": 1, "b": 2}, typing.Dict[int, str], False), + ({1: "a", 2: "b"}, typing.Dict[int, str], True), + ({1: "a", 2: "b"}, typing.Dict[str, int], False), + ({"a": 1, "b": "c"}, typing.Dict[str, int], False), + ([("a", 1), ("b", 2)], typing.Dict[str, int], False), + ({"key": "value"}, typing.Dict[str, str], True), + ({"key": 2}, typing.Dict[str, str], False), + ({"key": 2}, typing.Dict[str, int], True), + ({"key": 2.0}, typing.Dict[str, int], False), + ({"a": 1, "b": 2}, typing.Dict[str, int], True), + ({"a": 1, "b": "2"}, typing.Dict[str, int], False), ], ) def test_validate_type_dict(value, expected_type, expected_result): @@ -198,18 +198,18 @@ def test_validate_type_dict(value, expected_type, expected_result): @pytest.mark.parametrize( "value, expected_type, expected_result", [ - (42, set[int], False), - ({1, 2, 3}, set[int], True), - (42, Set[int], False), - ({1, 2, 3}, Set[int], True), - ({1, 2, 3}, Set[str], False), - ({"a", "b", "c"}, Set[str], True), - ({1, "a", 3}, Set[int], False), - (42, Set[int], False), - ({1, 2, 3}, Set[int], True), - ({1, "2", 3}, Set[int], False), - ([1, 2, 3], Set[int], False), - ("hello", Set[str], False), + (42, typing.Set[int], False), + ({1, 2, 3}, typing.Set[int], True), + (42, typing.Set[int], False), + ({1, 2, 3}, typing.Set[int], True), + ({1, 2, 3}, typing.Set[str], False), + ({"a", "b", "c"}, typing.Set[str], True), + ({1, "a", 3}, typing.Set[int], False), + (42, typing.Set[int], False), + ({1, 2, 3}, typing.Set[int], True), + ({1, "2", 3}, typing.Set[int], False), + ([1, 2, 3], typing.Set[int], False), + ("hello", typing.Set[str], False), ], ) def test_validate_type_set(value, expected_type, expected_result): @@ -224,21 +224,21 @@ def test_validate_type_set(value, expected_type, expected_result): @pytest.mark.parametrize( "value, expected_type, expected_result", [ - (42, tuple[int, str], False), - ((1, "a"), tuple[int, str], True), - (42, Tuple[int, str], False), - ((1, "a"), Tuple[int, str], True), - ((1, 2), Tuple[int, str], False), - ((1, 2), Tuple[int, int], True), - ((1, 2, 3), Tuple[int, int], False), - ((1, "a", 3.14), Tuple[int, str, float], True), - (("a", "b", "c"), Tuple[str, str, str], True), - ((1, "a", 3.14), Tuple[int, str], False), - ((1, "a", 3.14), Tuple[int, str, float], True), - ([1, "a", 3.14], Tuple[int, str, float], False), + (42, typing.Tuple[int, str], False), + ((1, "a"), typing.Tuple[int, str], True), + (42, typing.Tuple[int, str], False), + ((1, "a"), typing.Tuple[int, str], True), + ((1, 2), typing.Tuple[int, str], False), + ((1, 2), typing.Tuple[int, int], True), + ((1, 2, 3), typing.Tuple[int, int], False), + ((1, "a", 3.14), typing.Tuple[int, str, float], True), + (("a", "b", "c"), typing.Tuple[str, str, str], True), + ((1, "a", 3.14), typing.Tuple[int, str], False), + ((1, "a", 3.14), typing.Tuple[int, str, float], True), + ([1, "a", 3.14], typing.Tuple[int, str, float], False), ( (1, "a", 3.14, "b", True, None, (1, 2, 3)), - Tuple[int, str, float, str, bool, type(None), Tuple[int, int, int]], + typing.Tuple[int, str, float, str, bool, type(None), typing.Tuple[int, int, int]], True, ), ], @@ -258,7 +258,7 @@ def test_validate_type_tuple(value, expected_type, expected_result): (43, typing.Callable), (lambda x: x, typing.Callable), (42, typing.Callable[[], None]), - (42, typing.Callable[[int, str], list]), + (42, typing.Callable[[int, str], typing.List]), ], ) def test_validate_type_unsupported_type_hint(value, expected_type): @@ -270,14 +270,14 @@ def test_validate_type_unsupported_type_hint(value, expected_type): @pytest.mark.parametrize( "value, expected_type", [ - (42, list[int, str]), - ([1, 2, 3], list[int, str]), - ({"a": 1, "b": 2}, set[str, int]), - ({1: "a", 2: "b"}, set[int, str]), - ({"a": 1, "b": 2}, set[str, int, str]), - ({1: "a", 2: "b"}, set[int, str, int]), - ({1, 2, 3}, set[int, str]), - ({"a"}, set[int, str]), + (42, typing.List[int, str]), + ([1, 2, 3], typing.List[int, str]), + ({"a": 1, "b": 2}, typing.Set[str, int]), + ({1: "a", 2: "b"}, typing.Set[int, str]), + ({"a": 1, "b": 2}, typing.Set[str, int, str]), + ({1: "a", 2: "b"}, typing.Set[int, str, int]), + ({1, 2, 3}, typing.Set[int, str]), + ({"a"}, typing.Set[int, str]), ], ) def test_validate_type_unsupported_generic_alias(value, expected_type): @@ -289,14 +289,14 @@ def test_validate_type_unsupported_generic_alias(value, expected_type): @pytest.mark.parametrize( "value, expected_type, expected_result", [ - ([1, 2, 3], List[int], True), - (["a", "b", "c"], List[str], True), - ([1, "a", 3], List[int], False), - ([1, 2, [3, 4]], List[Union[int, List[int]]], True), - ([(1, 2), (3, 4)], List[Tuple[int, int]], True), - ([(1, 2), (3, "4")], List[Tuple[int, int]], False), - ({1: [1, 2], 2: [3, 4]}, Dict[int, List[int]], True), - ({1: [1, 2], 2: [3, "4"]}, Dict[int, List[int]], False), + ([1, 2, 3], typing.List[int], True), + (["a", "b", "c"], typing.List[str], True), + ([1, "a", 3], typing.List[int], False), + ([1, 2, [3, 4]], typing.List[Union[int, typing.List[int]]], True), + ([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]], True), + ([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]], False), + ({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]], True), + ({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]], False), ], ) def test_validate_type_collections(value, expected_type, expected_result): @@ -312,20 +312,20 @@ def test_validate_type_collections(value, expected_type, expected_result): "value, expected_type, expected_result", [ # empty lists - ([], List[int], True), - ([], list[dict], True), - ([], list[tuple[dict[tuple, str], str, None]], True), + ([], typing.List[int], True), + ([], typing.List[typing.Dict], True), + ([], typing.List[typing.Tuple[typing.Dict[typing.Tuple, str], str, None]], True), # empty dicts - ({}, Dict[str, int], True), - ({}, dict[str, dict], True), - ({}, dict[str, dict[str, int]], True), - ({}, dict[str, dict[str, int]], True), + ({}, typing.Dict[str, int], True), + ({}, typing.Dict[str, typing.Dict], True), + ({}, typing.Dict[str, typing.Dict[str, int]], True), + ({}, typing.Dict[str, typing.Dict[str, int]], True), # empty sets - (set(), Set[int], True), - (set(), set[dict], True), - (set(), set[tuple[dict[tuple, str], str, None]], True), + (set(), typing.Set[int], True), + (set(), typing.Set[typing.Dict], True), + (set(), typing.Set[typing.Tuple[typing.Dict[typing.Tuple, str], str, None]], True), # empty tuple - (tuple(), tuple, True), + (tuple(), typing.Tuple, True), # empty string ("", str, True), # empty bytes @@ -349,8 +349,8 @@ def test_validate_type_collections(value, expected_type, expected_result): (float("-inf"), float, True), (float(0), float, True), # list/tuple - ([1], tuple[int, int], False), - ((1, 2), list[int], False), + ([1], typing.Tuple[int, int], False), + ((1, 2), typing.List[int], False), ], ) def test_validate_type_edge_cases(value, expected_type, expected_result): @@ -365,17 +365,17 @@ def test_validate_type_edge_cases(value, expected_type, expected_result): @pytest.mark.parametrize( "value, expected_type, expected_result", [ - (42, list[int], False), + (42, typing.List[int], False), ([1, 2, 3], int, False), - (3.14, tuple[float], False), - (3.14, tuple[float, float], False), - (3.14, tuple[bool, str], False), - (False, tuple[bool, str], False), - (False, tuple[bool], False), - ((False,), tuple[bool], True), - (("abc",), tuple[str], True), - ("test-dict", dict[str, int], False), - ("test-dict", dict, False), + (3.14, typing.Tuple[float], False), + (3.14, typing.Tuple[float, float], False), + (3.14, typing.Tuple[bool, str], False), + (False, typing.Tuple[bool, str], False), + (False, typing.Tuple[bool], False), + ((False,), typing.Tuple[bool], True), + (("abc",), typing.Tuple[str], True), + ("test-dict", typing.Dict[str, int], False), + ("test-dict", typing.Dict, False), ], ) def test_validate_type_wrong_type(value, expected_type, expected_result): @@ -388,10 +388,10 @@ def test_validate_type_wrong_type(value, expected_type, expected_result): def test_validate_type_complex(): - assert validate_type([1, 2, [3, 4]], List[Union[int, List[int]]]) - assert validate_type({"a": 1, "b": {"c": 2}}, Dict[str, Union[int, Dict[str, int]]]) - assert validate_type({1, (2, 3)}, Set[Union[int, Tuple[int, int]]]) - assert validate_type((1, ("a", "b")), Tuple[int, Tuple[str, str]]) + assert validate_type([1, 2, [3, 4]], typing.List[Union[int, typing.List[int]]]) + assert validate_type({"a": 1, "b": {"c": 2}}, typing.Dict[str, Union[int, typing.Dict[str, int]]]) + assert validate_type({1, (2, 3)}, typing.Set[Union[int, typing.Tuple[int, int]]]) + assert validate_type((1, ("a", "b")), typing.Tuple[int, typing.Tuple[str, str]]) assert validate_type([{"key": "value"}], typing.List[typing.Dict[str, str]]) assert validate_type([{"key": 2}], typing.List[typing.Dict[str, str]]) == False assert validate_type([[1, 2], [3, 4]], typing.List[typing.List[int]]) @@ -421,14 +421,14 @@ def test_validate_type_complex(): @pytest.mark.parametrize( "value, expected_type, expected_result", [ - ([[[[1]]]], List[List[List[List[int]]]], True), - ([[[[1]]]], List[List[List[List[str]]]], False), - ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, int]]], True), - ({"a": {"b": {"c": 1}}}, Dict[str, Dict[str, Dict[str, str]]], False), - ({1, 2, 3}, Set[int], True), - ({1, 2, 3}, Set[str], False), - (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, int]], True), - (((1, 2), (3, 4)), Tuple[Tuple[int, int], Tuple[int, str]], False), + ([[[[1]]]], typing.List[typing.List[typing.List[typing.List[int]]]], True), + ([[[[1]]]], typing.List[typing.List[typing.List[typing.List[str]]]], False), + ({"a": {"b": {"c": 1}}}, typing.Dict[str, typing.Dict[str, typing.Dict[str, int]]], True), + ({"a": {"b": {"c": 1}}}, typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]], False), + ({1, 2, 3}, typing.Set[int], True), + ({1, 2, 3}, typing.Set[str], False), + (((1, 2), (3, 4)), typing.Tuple[typing.Tuple[int, int], typing.Tuple[int, int]], True), + (((1, 2), (3, 4)), typing.Tuple[typing.Tuple[int, int], typing.Tuple[int, str]], False), ], ) def test_validate_type_nested(value, expected_type, expected_result): From b15f85c0aee1182f09f7b78bac23d91d8078d281 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:37:05 -0700 Subject: [PATCH 097/158] moved all special union with pipe to other file --- .../unit/test_validate/test_validate_type.py | 16 -------- .../test_validate_type_unions_special.py | 38 +++++++++++++++++++ 2 files changed, 38 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test_validate/test_validate_type_unions_special.py diff --git a/tests/unit/test_validate/test_validate_type.py b/tests/unit/test_validate/test_validate_type.py index c535c15c..00bf7b9f 100644 --- a/tests/unit/test_validate/test_validate_type.py +++ b/tests/unit/test_validate/test_validate_type.py @@ -93,12 +93,6 @@ def test_validate_type_any(value): (5, Union[int, str], True), ("hello", Union[int, str], True), (5.0, Union[int, str], False), - (5, int | str, True), - ("hello", int | str, True), - (5.0, int | str, False), - (None, typing.Union[int, type(None)], True), - (None, typing.Union[int, str], False), - (None, int | str, False), ], ) def test_validate_type_union(value, expected_type, expected_result): @@ -122,16 +116,6 @@ def test_validate_type_union(value, expected_type, expected_result): (None, Optional[int], True), (None, Optional[None], True), (None, Optional[typing.List[typing.Dict[str, int]]], True), - (42, int | None, True), - ("hello", int | None, False), - (3.14, int | None, False), - ([1], typing.List[int] | None, True), - (None, int | None, True), - (None, str | None, True), - (None, None | str, True), - (None, None | int, True), - (None, str | int, False), - (None, None | typing.List[typing.Dict[str, int]], True), ], ) def test_validate_type_optional(value, expected_type, expected_result): diff --git a/tests/unit/test_validate/test_validate_type_unions_special.py b/tests/unit/test_validate/test_validate_type_unions_special.py new file mode 100644 index 00000000..11feb60e --- /dev/null +++ b/tests/unit/test_validate/test_validate_type_unions_special.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import typing +from typing import Any, Optional, Union + +import pytest + +from muutils.validate_type import validate_type + + +@pytest.mark.parametrize( + "value, expected_type, expected_result", + [ + (5, int | str, True), + ("hello", int | str, True), + (5.0, int | str, False), + (None, typing.Union[int, type(None)], True), + (None, typing.Union[int, str], False), + (None, int | str, False), + (42, int | None, True), + ("hello", int | None, False), + (3.14, int | None, False), + ([1], typing.List[int] | None, True), + (None, int | None, True), + (None, str | None, True), + (None, None | str, True), + (None, None | int, True), + (None, str | int, False), + (None, None | typing.List[typing.Dict[str, int]], True), + ], +) +def test_validate_type_union(value, expected_type, expected_result): + try: + assert validate_type(value, expected_type) == expected_result + except Exception as e: + raise Exception( + f"{value = }, {expected_type = }, {expected_result = }, {e}" + ) from e From 3003e7cc4a72732de38f7c53d8b4c79eeec7e66a Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:37:29 -0700 Subject: [PATCH 098/158] folder rename --- tests/unit/{test_validate => validate_type}/test_validate_type.py | 0 .../test_validate_type_unions_special.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/unit/{test_validate => validate_type}/test_validate_type.py (100%) rename tests/unit/{test_validate => validate_type}/test_validate_type_unions_special.py (100%) diff --git a/tests/unit/test_validate/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py similarity index 100% rename from tests/unit/test_validate/test_validate_type.py rename to tests/unit/validate_type/test_validate_type.py diff --git a/tests/unit/test_validate/test_validate_type_unions_special.py b/tests/unit/validate_type/test_validate_type_unions_special.py similarity index 100% rename from tests/unit/test_validate/test_validate_type_unions_special.py rename to tests/unit/validate_type/test_validate_type_unions_special.py From d020a5fd28f960b99b5f5e1c7180642ad600365b Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:40:25 -0700 Subject: [PATCH 099/158] moved some invalid tests over --- tests/unit/validate_type/test_validate_type.py | 18 ------------------ .../test_validate_type_unions_special.py | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index 00bf7b9f..77e1b5eb 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -251,24 +251,6 @@ def test_validate_type_unsupported_type_hint(value, expected_type): print(f"Failed to except: {value = }, {expected_type = }") -@pytest.mark.parametrize( - "value, expected_type", - [ - (42, typing.List[int, str]), - ([1, 2, 3], typing.List[int, str]), - ({"a": 1, "b": 2}, typing.Set[str, int]), - ({1: "a", 2: "b"}, typing.Set[int, str]), - ({"a": 1, "b": 2}, typing.Set[str, int, str]), - ({1: "a", 2: "b"}, typing.Set[int, str, int]), - ({1, 2, 3}, typing.Set[int, str]), - ({"a"}, typing.Set[int, str]), - ], -) -def test_validate_type_unsupported_generic_alias(value, expected_type): - with pytest.raises(TypeError): - validate_type(value, expected_type) - print(f"Failed to except: {value = }, {expected_type = }") - @pytest.mark.parametrize( "value, expected_type, expected_result", diff --git a/tests/unit/validate_type/test_validate_type_unions_special.py b/tests/unit/validate_type/test_validate_type_unions_special.py index 11feb60e..f27098ce 100644 --- a/tests/unit/validate_type/test_validate_type_unions_special.py +++ b/tests/unit/validate_type/test_validate_type_unions_special.py @@ -36,3 +36,21 @@ def test_validate_type_union(value, expected_type, expected_result): raise Exception( f"{value = }, {expected_type = }, {expected_result = }, {e}" ) from e + +@pytest.mark.parametrize( + "value, expected_type", + [ + (42, typing.List[int, str]), + ([1, 2, 3], typing.List[int, str]), + ({"a": 1, "b": 2}, typing.Set[str, int]), + ({1: "a", 2: "b"}, typing.Set[int, str]), + ({"a": 1, "b": 2}, typing.Set[str, int, str]), + ({1: "a", 2: "b"}, typing.Set[int, str, int]), + ({1, 2, 3}, typing.Set[int, str]), + ({"a"}, typing.Set[int, str]), + ], +) +def test_validate_type_unsupported_generic_alias(value, expected_type): + with pytest.raises(TypeError): + validate_type(value, expected_type) + print(f"Failed to except: {value = }, {expected_type = }") \ No newline at end of file From af8717e7b0950b7ba16e8ff180bc31069495628d Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 01:56:40 -0700 Subject: [PATCH 100/158] makefile now detects version, goes into compatibility mode --- .github/workflows/checks.yml | 7 +---- makefile | 27 ++++++++++++------- ...ecial.py => test_validate_type_special.py} | 0 3 files changed, 18 insertions(+), 16 deletions(-) rename tests/unit/validate_type/{test_validate_type_unions_special.py => test_validate_type_special.py} (100%) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 1ca2ca0c..99e214fa 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -118,9 +118,4 @@ jobs: run: make test WARN_STRICT=1 RUN_GLOBAL=1 - name: check typing - if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} - run: make typing RUN_GLOBAL=1 - - - name: check typing in compatibility mode - if: ${{ matrix.versions.python == '3.8' || matrix.versions.python == '3.9' }} - run: make typing-compat RUN_GLOBAL=1 \ No newline at end of file + run: make typing RUN_GLOBAL=1 \ No newline at end of file diff --git a/makefile b/makefile index 2d790652..54e194e0 100644 --- a/makefile +++ b/makefile @@ -21,8 +21,6 @@ COMMIT_LOG_FILE := .commit_log COMMIT_LOG_SINCE_LAST_VERSION := $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t') # 1 2 3 4 5 -TYPECHECK_COMPAT_ARGS := --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found - .PHONY: default default: help @@ -47,6 +45,11 @@ else PYTHON = $(PYTHON_BASE) endif +PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") + +COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 10) else 0)") + +TYPECHECK_ARGS ?= # formatting # -------------------------------------------------- @@ -82,6 +85,14 @@ ifneq ($(WARN_STRICT), 0) PYTEST_OPTIONS += -W error endif +# compatibility mode for python <3.10 + +# Update the PYTEST_OPTIONS to include the conditional ignore option +# @echo "WARNING: Detected python version less than 3.10, some behavior will be different" +ifeq ($(COMPATIBILITY_MODE), 1) + PYTEST_OPTIONS += --ignore=tests/unit/validate_type/test_validate_type_special.py + TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found +endif .PHONY: cov cov: @@ -101,17 +112,12 @@ cov: .PHONY: typing typing: clean @echo "running type checks" - $(PYTHON) -m mypy --config-file $(PYPROJECT) $(PACKAGE_NAME)/ - $(PYTHON) -m mypy --config-file $(PYPROJECT) tests/ + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) $(PACKAGE_NAME)/ + $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_ARGS) tests/ -.PHONY: typing-compat -typing-compat: clean - @echo "running type checks in compatibility mode for older python versions" - $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_COMPAT_ARGS) $(PACKAGE_NAME)/ - $(PYTHON) -m mypy --config-file $(PYPROJECT) $(TYPECHECK_COMPAT_ARGS) tests/ .PHONY: test -test: clean +test: @echo "running tests" $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) @@ -206,6 +212,7 @@ help: @cat Makefile | sed -n '/^\.PHONY: / h; /\(^\t@*echo\|^\t:\)/ {H; x; /PHONY/ s/.PHONY: \(.*\)\n.*"\(.*\)"/ make \1\t\2/p; d; x}'| sort -k2,2 |expand -t 25 @echo "# makefile variables:" @echo " PYTHON = $(PYTHON)" + @echo " PYTHON_VERSION = $(PYTHON_VERSION)" @echo " PACKAGE_NAME = $(PACKAGE_NAME)" @echo " VERSION = $(VERSION)" @echo " LAST_VERSION = $(LAST_VERSION)" diff --git a/tests/unit/validate_type/test_validate_type_unions_special.py b/tests/unit/validate_type/test_validate_type_special.py similarity index 100% rename from tests/unit/validate_type/test_validate_type_unions_special.py rename to tests/unit/validate_type/test_validate_type_special.py From 1c3b824162730ac00e76bf1f6075eb29b708898c Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:02:28 -0700 Subject: [PATCH 101/158] make format --- makefile | 2 +- muutils/validate_type.py | 7 ++- .../unit/validate_type/test_validate_type.py | 45 +++++++++++++++---- .../test_validate_type_special.py | 3 +- 4 files changed, 45 insertions(+), 12 deletions(-) diff --git a/makefile b/makefile index 54e194e0..46d1edf4 100644 --- a/makefile +++ b/makefile @@ -90,7 +90,7 @@ endif # Update the PYTEST_OPTIONS to include the conditional ignore option # @echo "WARNING: Detected python version less than 3.10, some behavior will be different" ifeq ($(COMPATIBILITY_MODE), 1) - PYTEST_OPTIONS += --ignore=tests/unit/validate_type/test_validate_type_special.py + PYTEST_OPTIONS += --ignore=tests/unit/validate_type/ TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found endif diff --git a/muutils/validate_type.py b/muutils/validate_type.py index a9c31c0f..da9b3131 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -25,8 +25,13 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # useful for debugging # print(f"{value = }, {expected_type = }, {origin = }, {args = }") + UnionType = getattr(types, "UnionType", None) - if origin is types.UnionType or origin is typing.Union: + if (origin is typing.Union) or ( # this works in python <3.10 + False + if UnionType is None # return False if UnionType is not available + else origin is UnionType # return True if UnionType is available + ): return any(validate_type(value, arg) for arg in args) # generic alias, more complicated diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index 77e1b5eb..de629dfb 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -222,7 +222,9 @@ def test_validate_type_set(value, expected_type, expected_result): ([1, "a", 3.14], typing.Tuple[int, str, float], False), ( (1, "a", 3.14, "b", True, None, (1, 2, 3)), - typing.Tuple[int, str, float, str, bool, type(None), typing.Tuple[int, int, int]], + typing.Tuple[ + int, str, float, str, bool, type(None), typing.Tuple[int, int, int] + ], True, ), ], @@ -251,7 +253,6 @@ def test_validate_type_unsupported_type_hint(value, expected_type): print(f"Failed to except: {value = }, {expected_type = }") - @pytest.mark.parametrize( "value, expected_type, expected_result", [ @@ -280,7 +281,11 @@ def test_validate_type_collections(value, expected_type, expected_result): # empty lists ([], typing.List[int], True), ([], typing.List[typing.Dict], True), - ([], typing.List[typing.Tuple[typing.Dict[typing.Tuple, str], str, None]], True), + ( + [], + typing.List[typing.Tuple[typing.Dict[typing.Tuple, str], str, None]], + True, + ), # empty dicts ({}, typing.Dict[str, int], True), ({}, typing.Dict[str, typing.Dict], True), @@ -289,7 +294,11 @@ def test_validate_type_collections(value, expected_type, expected_result): # empty sets (set(), typing.Set[int], True), (set(), typing.Set[typing.Dict], True), - (set(), typing.Set[typing.Tuple[typing.Dict[typing.Tuple, str], str, None]], True), + ( + set(), + typing.Set[typing.Tuple[typing.Dict[typing.Tuple, str], str, None]], + True, + ), # empty tuple (tuple(), typing.Tuple, True), # empty string @@ -355,7 +364,9 @@ def test_validate_type_wrong_type(value, expected_type, expected_result): def test_validate_type_complex(): assert validate_type([1, 2, [3, 4]], typing.List[Union[int, typing.List[int]]]) - assert validate_type({"a": 1, "b": {"c": 2}}, typing.Dict[str, Union[int, typing.Dict[str, int]]]) + assert validate_type( + {"a": 1, "b": {"c": 2}}, typing.Dict[str, Union[int, typing.Dict[str, int]]] + ) assert validate_type({1, (2, 3)}, typing.Set[Union[int, typing.Tuple[int, int]]]) assert validate_type((1, ("a", "b")), typing.Tuple[int, typing.Tuple[str, str]]) assert validate_type([{"key": "value"}], typing.List[typing.Dict[str, str]]) @@ -389,12 +400,28 @@ def test_validate_type_complex(): [ ([[[[1]]]], typing.List[typing.List[typing.List[typing.List[int]]]], True), ([[[[1]]]], typing.List[typing.List[typing.List[typing.List[str]]]], False), - ({"a": {"b": {"c": 1}}}, typing.Dict[str, typing.Dict[str, typing.Dict[str, int]]], True), - ({"a": {"b": {"c": 1}}}, typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]], False), + ( + {"a": {"b": {"c": 1}}}, + typing.Dict[str, typing.Dict[str, typing.Dict[str, int]]], + True, + ), + ( + {"a": {"b": {"c": 1}}}, + typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]], + False, + ), ({1, 2, 3}, typing.Set[int], True), ({1, 2, 3}, typing.Set[str], False), - (((1, 2), (3, 4)), typing.Tuple[typing.Tuple[int, int], typing.Tuple[int, int]], True), - (((1, 2), (3, 4)), typing.Tuple[typing.Tuple[int, int], typing.Tuple[int, str]], False), + ( + ((1, 2), (3, 4)), + typing.Tuple[typing.Tuple[int, int], typing.Tuple[int, int]], + True, + ), + ( + ((1, 2), (3, 4)), + typing.Tuple[typing.Tuple[int, int], typing.Tuple[int, str]], + False, + ), ], ) def test_validate_type_nested(value, expected_type, expected_result): diff --git a/tests/unit/validate_type/test_validate_type_special.py b/tests/unit/validate_type/test_validate_type_special.py index f27098ce..e8f6803a 100644 --- a/tests/unit/validate_type/test_validate_type_special.py +++ b/tests/unit/validate_type/test_validate_type_special.py @@ -37,6 +37,7 @@ def test_validate_type_union(value, expected_type, expected_result): f"{value = }, {expected_type = }, {expected_result = }, {e}" ) from e + @pytest.mark.parametrize( "value, expected_type", [ @@ -53,4 +54,4 @@ def test_validate_type_union(value, expected_type, expected_result): def test_validate_type_unsupported_generic_alias(value, expected_type): with pytest.raises(TypeError): validate_type(value, expected_type) - print(f"Failed to except: {value = }, {expected_type = }") \ No newline at end of file + print(f"Failed to except: {value = }, {expected_type = }") From 32062f06052134e682a85376ca05f017ede65215 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:04:01 -0700 Subject: [PATCH 102/158] update makefile --- makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/makefile b/makefile index 46d1edf4..9209de6b 100644 --- a/makefile +++ b/makefile @@ -88,8 +88,8 @@ endif # compatibility mode for python <3.10 # Update the PYTEST_OPTIONS to include the conditional ignore option -# @echo "WARNING: Detected python version less than 3.10, some behavior will be different" ifeq ($(COMPATIBILITY_MODE), 1) + JUNK := $(info WARNING: Detected python version less than 3.10, some behavior will be different) PYTEST_OPTIONS += --ignore=tests/unit/validate_type/ TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found endif @@ -117,7 +117,7 @@ typing: clean .PHONY: test -test: +test: clean @echo "running tests" $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) From f60fe405b629dc04ec0492999021588645c8ed6c Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:23:31 -0700 Subject: [PATCH 103/158] cursed --- .../json_serialize/serializable_dataclass.py | 13 ++++-- muutils/validate_type.py | 41 ++++++++++++------- .../test_sdc_defaults.py | 2 +- .../test_sdc_properties_nested.py | 2 +- .../test_serializable_dataclass.py | 16 +++++--- 5 files changed, 49 insertions(+), 25 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 1a5dbe06..61a37209 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -398,6 +398,10 @@ def __deepcopy__(self, memo: dict) -> "SerializableDataclass": return self.__class__.load(self.serialize()) +class CantGetTypeHintsWarning(UserWarning): + pass + + # Step 3: Create a custom serializable_dataclass decorator # TODO: add a kwarg for always asserting type for all fields def serializable_dataclass( @@ -522,7 +526,8 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: + " - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x)\n" + " - use python 3.9.x or higher\n" + " - add explicit loading functions to the fields\n" - + f" {dataclasses.fields(cls) = }" + + f" {dataclasses.fields(cls) = }", + CantGetTypeHintsWarning, ) cls_type_hints = dict() else: @@ -588,7 +593,8 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: ) from e else: raise ValueError( - f"Cannot get type hints for {cls.__name__}, and so cannot validate. Python version is {sys.version_info = }. You can:\n" + f"Cannot get type hints for {cls.__name__}, field {field.name = } and so cannot validate." + + f"Python version is {sys.version_info = }. You can:\n" + f" - disable `assert_type`. Currently: {field.assert_type = }\n" + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {field.type = }\n" + " - use python 3.9.x or higher\n" @@ -597,7 +603,8 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: else: # TODO: raise an exception here? Can't validate if data given warnings.warn( - f"Field '{field.name}' on class {cls} has no type hint, but {field.assert_type = }\n{field = }\n{cls_type_hints = }\n{data = }" + f"Field '{field.name}' on class {cls} has no type hint, but {field.assert_type = }\n{field = }\n{cls_type_hints = }\n{data = }", + CantGetTypeHintsWarning, ) return cls(**ctor_kwargs) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index da9b3131..d8632c77 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -3,6 +3,20 @@ import types import typing +# this is also for python <3.10 compatibility +_GenericAliasTypeNames: typing.List[str] = [ + "GenericAlias", + "_GenericAlias", + "_UnionGenericAlias", + "_BaseGenericAlias", +] + +_GenericAliasTypesList: list = [ + getattr(typing, name, None) for name in _GenericAliasTypeNames +] + +GenericAliasTypes: tuple = tuple([t for t in _GenericAliasTypesList if t is not None]) + def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: """Validate that a value is of the expected type. use `typeguard` for a more robust solution. @@ -35,15 +49,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return any(validate_type(value, arg) for arg in args) # generic alias, more complicated - if isinstance( - expected_type, - ( - typing.GenericAlias, - typing._GenericAlias, - typing._UnionGenericAlias, - typing._BaseGenericAlias, - ), - ): + if isinstance(expected_type, GenericAliasTypes): if origin is list: # no args @@ -52,7 +58,8 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # incorrect number of args if len(args) != 1: raise TypeError( - f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }" + f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }", + f"{GenericAliasTypes = }", ) # check is list if not isinstance(value, list): @@ -68,7 +75,8 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # incorrect number of args if len(args) != 2: raise TypeError( - f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }" + f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }", + f"{GenericAliasTypes = }", ) # check is dict if not isinstance(value, dict): @@ -88,7 +96,8 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # incorrect number of args if len(args) != 1: raise TypeError( - f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }" + f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }", + f"{GenericAliasTypes = }", ) # check is set if not isinstance(value, set): @@ -113,8 +122,12 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # TODO: Callables, etc. raise ValueError( - f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }" + f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }", + f"{GenericAliasTypes = }", ) else: - raise ValueError(f"Unsupported type hint {expected_type = } for {value = }") + raise ValueError( + f"Unsupported type hint {expected_type = } for {value = }", + f"{GenericAliasTypes = }", + ) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 7ea89c7a..0adc6d11 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -14,7 +14,7 @@ # pylint: disable=missing-class-docstring -BELOW_PY_3_10: bool = sys.version_info < (3, 10) +BELOW_PY_3_10: bool = False # sys.version_info < (3, 10) def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index a1440bf6..e1b0324d 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -11,7 +11,7 @@ print(f"{SUPPORTS_KW_ONLY = }") -BELOW_PY_3_10: bool = sys.version_info < (3, 10) +BELOW_PY_3_10: bool = False # sys.version_info < (3, 10) def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 31ecd1bc..d8adb43b 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import typing from typing import Any import pytest @@ -14,7 +15,8 @@ # pylint: disable=missing-class-docstring, unused-variable -BELOW_PY_3_10: bool = sys.version_info < (3, 10) +BELOW_PY_3_10: bool = False +# sys.version_info < (3, 10) def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: @@ -35,7 +37,7 @@ def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> An class BasicAutofields(SerializableDataclass): a: str b: int - c: list[int] + c: typing.List[int] def test_basic_auto_fields(): @@ -71,7 +73,7 @@ def test_basic_diff(): class SimpleFields(SerializableDataclass): d: str e: int = 42 - f: list[int] = serializable_field(default_factory=list) + f: typing.List[int] = serializable_field(default_factory=list) # noqa: F821 @serializable_dataclass @@ -126,7 +128,9 @@ def test_simple_fields_serialization(simple_fields_instance): def test_simple_fields_loading(simple_fields_instance): serialized = simple_fields_instance.serialize() - loaded = _loading_test_wrapper(SimpleFields, serialized) # , assert_record_len=4) + loaded = SimpleFields.load( + serialized + ) # _loading_test_wrapper(SimpleFields, serialized) # , assert_record_len=4) assert loaded == simple_fields_instance assert loaded.diff(simple_fields_instance) == {} @@ -267,7 +271,7 @@ def test_person_serialization(): class FullPerson(SerializableDataclass): name: str = serializable_field() age: int = serializable_field(default=-1) - items: list[str] = serializable_field(default_factory=list) + items: typing.List[str] = serializable_field(default_factory=list) @property def full_name(self) -> str: @@ -313,7 +317,7 @@ class CustomSerialization(SerializableDataclass): class Nested_with_Container(SerializableDataclass): val_int: int val_str: str - val_list: list[BasicAutofields] = serializable_field( + val_list: typing.List[BasicAutofields] = serializable_field( default_factory=list, serialization_fn=lambda x: [y.serialize() for y in x], loading_fn=lambda x: [BasicAutofields.load(y) for y in x["val_list"]], From 7a6a0e7a0fe021e666e9d56b21beb094ece61a0e Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:25:28 -0700 Subject: [PATCH 104/158] clean up _loading_test_wrapper, remove it entirely later --- .../test_sdc_defaults.py | 17 ++++---------- .../test_sdc_properties_nested.py | 16 +++----------- .../test_serializable_dataclass.py | 22 ++++--------------- 3 files changed, 11 insertions(+), 44 deletions(-) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 0adc6d11..4119c73a 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -14,21 +14,12 @@ # pylint: disable=missing-class-docstring -BELOW_PY_3_10: bool = False # sys.version_info < (3, 10) - -def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: +# TODO: get rid of all _loading_test_wrapper functions across all files +def _loading_test_wrapper(cls, data) -> Any: """wrapper for testing the load function, which accounts for version differences""" - if BELOW_PY_3_10: - with pytest.warns(UserWarning) as record: - loaded = cls.load(data) - print([x.message for x in record]) - if assert_record_len is not None: - assert len(record) == assert_record_len - return loaded - else: - loaded = cls.load(data) - return loaded + loaded = cls.load(data) + return loaded @serializable_dataclass diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index e1b0324d..35f252fc 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -11,21 +11,11 @@ print(f"{SUPPORTS_KW_ONLY = }") -BELOW_PY_3_10: bool = False # sys.version_info < (3, 10) - -def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: +def _loading_test_wrapper(cls, data) -> Any: """wrapper for testing the load function, which accounts for version differences""" - if BELOW_PY_3_10: - with pytest.warns(UserWarning) as record: - loaded = cls.load(data) - print([x.message for x in record]) - if assert_record_len is not None: - assert len(record) == assert_record_len - return loaded - else: - loaded = cls.load(data) - return loaded + loaded = cls.load(data) + return loaded @serializable_dataclass diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index d8adb43b..e646825f 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -15,22 +15,10 @@ # pylint: disable=missing-class-docstring, unused-variable -BELOW_PY_3_10: bool = False -# sys.version_info < (3, 10) - - -def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any: +def _loading_test_wrapper(cls, data) -> Any: """wrapper for testing the load function, which accounts for version differences""" - if BELOW_PY_3_10: - with pytest.warns(UserWarning) as record: - loaded = cls.load(data) - print([x.message for x in record]) - if assert_record_len is not None: - assert len(record) == assert_record_len - return loaded - else: - loaded = cls.load(data) - return loaded + loaded = cls.load(data) + return loaded @serializable_dataclass @@ -128,9 +116,7 @@ def test_simple_fields_serialization(simple_fields_instance): def test_simple_fields_loading(simple_fields_instance): serialized = simple_fields_instance.serialize() - loaded = SimpleFields.load( - serialized - ) # _loading_test_wrapper(SimpleFields, serialized) # , assert_record_len=4) + loaded = SimpleFields.load(serialized) assert loaded == simple_fields_instance assert loaded.diff(simple_fields_instance) == {} From 41e43c78335833d49d69ecb115347b2fb4cfa01f Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:36:33 -0700 Subject: [PATCH 105/158] fix invalid type check --- .../validate_type/test_validate_type_special.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/unit/validate_type/test_validate_type_special.py b/tests/unit/validate_type/test_validate_type_special.py index e8f6803a..8c276b5e 100644 --- a/tests/unit/validate_type/test_validate_type_special.py +++ b/tests/unit/validate_type/test_validate_type_special.py @@ -41,14 +41,15 @@ def test_validate_type_union(value, expected_type, expected_result): @pytest.mark.parametrize( "value, expected_type", [ - (42, typing.List[int, str]), - ([1, 2, 3], typing.List[int, str]), - ({"a": 1, "b": 2}, typing.Set[str, int]), - ({1: "a", 2: "b"}, typing.Set[int, str]), - ({"a": 1, "b": 2}, typing.Set[str, int, str]), - ({1: "a", 2: "b"}, typing.Set[int, str, int]), - ({1, 2, 3}, typing.Set[int, str]), - ({"a"}, typing.Set[int, str]), + (42, list[int, str]), + ([1, 2, 3], list[int, str]), + ({"a": 1, "b": 2}, set[str, int]), + ({1: "a", 2: "b"}, set[int, str]), + ({"a": 1, "b": 2}, set[str, int, str]), + ({1: "a", 2: "b"}, set[int, str, int]), + ({1, 2, 3}, set[int, str]), + ({"a"}, set[int, str]), + (42, dict[int, str, bool]), ], ) def test_validate_type_unsupported_generic_alias(value, expected_type): From 243e54f27269c325fdddfe712109fc01ef54ec3d Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:36:47 -0700 Subject: [PATCH 106/158] cursedgs! --- .gitignore | 8 ++++++++ makefile | 5 +++++ tests/util/replace_type_hints.py | 23 +++++++++++++++++++++++ tests/{manual => util}/test_fire.py | 0 4 files changed, 36 insertions(+) create mode 100644 tests/util/replace_type_hints.py rename tests/{manual => util}/test_fire.py (100%) diff --git a/.gitignore b/.gitignore index 04d04471..87aa3dc4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,17 +1,25 @@ +# this one is cursed +tests/unit/validate_type/_test_validate_type_MODERN.py +# test notebook _test.ipynb +# junk data JUNK_DATA_PATH/ junk_data +# misc .pypi-token .commit_log .vscode/ +# caches __pycache__/ **/__pycache__/ **/.mypy_cache/ **/.pytest_cache/ +# coverage .coverage htmlcov/ +# build build/ dist/ muutils.egg-info/ diff --git a/makefile b/makefile index 9209de6b..c0a9aed5 100644 --- a/makefile +++ b/makefile @@ -119,6 +119,11 @@ typing: clean .PHONY: test test: clean @echo "running tests" + + if [ $(COMPATIBILITY_MODE) -eq 1 ]; then \ + echo "converting certain tests to modern format"; \ + python tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py > tests/unit/validate_type/_test_validate_type_MODERN.py; \ + fi; \ $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) diff --git a/tests/util/replace_type_hints.py b/tests/util/replace_type_hints.py new file mode 100644 index 00000000..cb577591 --- /dev/null +++ b/tests/util/replace_type_hints.py @@ -0,0 +1,23 @@ +def replace_typing_aliases(filename): + # Dictionary to map old types from the typing module to the new built-in types + replacements = { + "typing.List": "list", + "typing.Dict": "dict", + "typing.Set": "set", + "typing.Tuple": "tuple" + } + + # Read the file content + with open(filename, 'r') as file: + content = file.read() + + # Replace all occurrences of the typing module types with built-in types + for old, new in replacements.items(): + content = content.replace(old, new) + + # Print the modified content to stdout + print(content) + +if __name__ == "__main__": + import sys + replace_typing_aliases(sys.argv[1]) \ No newline at end of file diff --git a/tests/manual/test_fire.py b/tests/util/test_fire.py similarity index 100% rename from tests/manual/test_fire.py rename to tests/util/test_fire.py From b49bdec8330dc9ff7b41f7d68956708d1b145940 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:42:33 -0700 Subject: [PATCH 107/158] so so so cursed --- .gitignore | 2 +- makefile | 5 +++-- tests/util/replace_type_hints.py | 18 ++++++++++-------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 87aa3dc4..2d99a757 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # this one is cursed -tests/unit/validate_type/_test_validate_type_MODERN.py +tests/unit/validate_type/test_validate_type_MODERN.py # test notebook _test.ipynb # junk data diff --git a/makefile b/makefile index c0a9aed5..8ff3046f 100644 --- a/makefile +++ b/makefile @@ -120,9 +120,9 @@ typing: clean test: clean @echo "running tests" - if [ $(COMPATIBILITY_MODE) -eq 1 ]; then \ + if [ $(COMPATIBILITY_MODE) -eq 0 ]; then \ echo "converting certain tests to modern format"; \ - python tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py > tests/unit/validate_type/_test_validate_type_MODERN.py; \ + $(PYTHON) tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py > tests/unit/validate_type/test_validate_type_MODERN.py; \ fi; \ $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) @@ -207,6 +207,7 @@ clean: rm -rf tests/junk_data python -Bc "import pathlib; [p.unlink() for p in pathlib.Path('.').rglob('*.py[co]')]" python -Bc "import pathlib; [p.rmdir() for p in pathlib.Path('.').rglob('__pycache__')]" + rm -rf tests/unit/validate_type/test_validate_type_MODERN.py # listing targets, from stackoverflow # https://stackoverflow.com/questions/4219255/how-do-you-get-the-list-of-targets-in-a-makefile diff --git a/tests/util/replace_type_hints.py b/tests/util/replace_type_hints.py index cb577591..c024dd18 100644 --- a/tests/util/replace_type_hints.py +++ b/tests/util/replace_type_hints.py @@ -4,20 +4,22 @@ def replace_typing_aliases(filename): "typing.List": "list", "typing.Dict": "dict", "typing.Set": "set", - "typing.Tuple": "tuple" + "typing.Tuple": "tuple", } - + # Read the file content - with open(filename, 'r') as file: + with open(filename, "r") as file: content = file.read() - + # Replace all occurrences of the typing module types with built-in types for old, new in replacements.items(): content = content.replace(old, new) - + # Print the modified content to stdout print(content) - + + if __name__ == "__main__": - import sys - replace_typing_aliases(sys.argv[1]) \ No newline at end of file + import sys + + replace_typing_aliases(sys.argv[1]) From 0f1deed20eeab038833306b5fb4ae9908544c62e Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:43:52 -0700 Subject: [PATCH 108/158] wooooo --- .github/workflows/checks.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 99e214fa..ca9b416d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -115,6 +115,8 @@ jobs: run: make test RUN_GLOBAL=1 - name: tests in strict mode + # until zanj ported to 3.8 and 3.9 + if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} run: make test WARN_STRICT=1 RUN_GLOBAL=1 - name: check typing From af7236628c4105c5a64b61822d884e7ccbdf2535 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:46:08 -0700 Subject: [PATCH 109/158] ! --- .github/workflows/checks.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index ca9b416d..a1d6b461 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -107,7 +107,7 @@ jobs: run: uv pip install . --system - name: Install zanj (>=3.10 only) - # not yet available for python 3.8 and 3.9 + # TODO: not yet available for python 3.8 and 3.9 if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} run: uv pip install zanj --system @@ -115,9 +115,11 @@ jobs: run: make test RUN_GLOBAL=1 - name: tests in strict mode - # until zanj ported to 3.8 and 3.9 + # TODO: until zanj ported to 3.8 and 3.9 if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} run: make test WARN_STRICT=1 RUN_GLOBAL=1 - name: check typing + # TODO: idx if this is even possible to fix + if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} run: make typing RUN_GLOBAL=1 \ No newline at end of file From 05867bce209d61e4c8e2f8cc451f24e460033881 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:52:01 -0700 Subject: [PATCH 110/158] fix typing! --- .github/workflows/checks.yml | 2 -- muutils/validate_type.py | 9 +++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index a1d6b461..7cff1f6e 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -120,6 +120,4 @@ jobs: run: make test WARN_STRICT=1 RUN_GLOBAL=1 - name: check typing - # TODO: idx if this is even possible to fix - if: ${{ matrix.versions.python != '3.8' && matrix.versions.python != '3.9' }} run: make typing RUN_GLOBAL=1 \ No newline at end of file diff --git a/muutils/validate_type.py b/muutils/validate_type.py index d8632c77..ae218e1c 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -34,8 +34,8 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: except TypeError: pass - origin: type = typing.get_origin(expected_type) - args: list = typing.get_args(expected_type) + origin: typing.Any = typing.get_origin(expected_type) + args: tuple = typing.get_args(expected_type) # useful for debugging # print(f"{value = }, {expected_type = }, {origin = }, {args = }") @@ -49,6 +49,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return any(validate_type(value, arg) for arg in args) # generic alias, more complicated + item_type: type if isinstance(expected_type, GenericAliasTypes): if origin is list: @@ -65,7 +66,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: if not isinstance(value, list): return False # check all items in list are of the correct type - item_type: type = args[0] + item_type = args[0] return all(validate_type(item, item_type) for item in value) if origin is dict: @@ -103,7 +104,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: if not isinstance(value, set): return False # check all items in set are of the correct type - item_type: type = args[0] + item_type = args[0] return all(validate_type(item, item_type) for item in value) if origin is tuple: From ac96afb8106ce02f2d87a90ad280935d1f5e0725 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:53:21 -0700 Subject: [PATCH 111/158] fix typing in some tests literally in the tests for invalid types --- .../test_validate_type_special.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/unit/validate_type/test_validate_type_special.py b/tests/unit/validate_type/test_validate_type_special.py index 8c276b5e..4561fc4d 100644 --- a/tests/unit/validate_type/test_validate_type_special.py +++ b/tests/unit/validate_type/test_validate_type_special.py @@ -41,15 +41,15 @@ def test_validate_type_union(value, expected_type, expected_result): @pytest.mark.parametrize( "value, expected_type", [ - (42, list[int, str]), - ([1, 2, 3], list[int, str]), - ({"a": 1, "b": 2}, set[str, int]), - ({1: "a", 2: "b"}, set[int, str]), - ({"a": 1, "b": 2}, set[str, int, str]), - ({1: "a", 2: "b"}, set[int, str, int]), - ({1, 2, 3}, set[int, str]), - ({"a"}, set[int, str]), - (42, dict[int, str, bool]), + (42, list[int, str]), # type: ignore[misc] + ([1, 2, 3], list[int, str]), # type: ignore[misc] + ({"a": 1, "b": 2}, set[str, int]), # type: ignore[misc] + ({1: "a", 2: "b"}, set[int, str]), # type: ignore[misc] + ({"a": 1, "b": 2}, set[str, int, str]), # type: ignore[misc] + ({1: "a", 2: "b"}, set[int, str, int]), # type: ignore[misc] + ({1, 2, 3}, set[int, str]), # type: ignore[misc] + ({"a"}, set[int, str]), # type: ignore[misc] + (42, dict[int, str, bool]), # type: ignore[misc] ], ) def test_validate_type_unsupported_generic_alias(value, expected_type): From cfb0474ff98775980b6092574a55d0c746ec4205 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 02:55:23 -0700 Subject: [PATCH 112/158] fix more typing, make whole workflow parallel again --- .github/workflows/checks.yml | 2 +- .../test_validate_type_special.py | 32 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 7cff1f6e..3c4237f4 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -67,7 +67,7 @@ jobs: test: name: Test and Lint runs-on: ubuntu-latest - needs: [lint, check-deps] + # needs: [lint, check-deps] strategy: matrix: versions: diff --git a/tests/unit/validate_type/test_validate_type_special.py b/tests/unit/validate_type/test_validate_type_special.py index 4561fc4d..4ca4f069 100644 --- a/tests/unit/validate_type/test_validate_type_special.py +++ b/tests/unit/validate_type/test_validate_type_special.py @@ -11,22 +11,22 @@ @pytest.mark.parametrize( "value, expected_type, expected_result", [ - (5, int | str, True), - ("hello", int | str, True), - (5.0, int | str, False), - (None, typing.Union[int, type(None)], True), - (None, typing.Union[int, str], False), - (None, int | str, False), - (42, int | None, True), - ("hello", int | None, False), - (3.14, int | None, False), - ([1], typing.List[int] | None, True), - (None, int | None, True), - (None, str | None, True), - (None, None | str, True), - (None, None | int, True), - (None, str | int, False), - (None, None | typing.List[typing.Dict[str, int]], True), + (5, int | str, True), # type: ignore[operator] + ("hello", int | str, True), # type: ignore[operator] + (5.0, int | str, False), # type: ignore[operator] + (None, typing.Union[int, type(None)], True), # type: ignore[operator] + (None, typing.Union[int, str], False), # type: ignore[operator] + (None, int | str, False), # type: ignore[operator] + (42, int | None, True), # type: ignore[operator] + ("hello", int | None, False), # type: ignore[operator] + (3.14, int | None, False), # type: ignore[operator] + ([1], typing.List[int] | None, True), # type: ignore[operator] + (None, int | None, True), # type: ignore[operator] + (None, str | None, True), # type: ignore[operator] + (None, None | str, True), # type: ignore[operator] + (None, None | int, True), # type: ignore[operator] + (None, str | int, False), # type: ignore[operator] + (None, None | typing.List[typing.Dict[str, int]], True), # type: ignore[operator] ], ) def test_validate_type_union(value, expected_type, expected_result): From 200cb3167afbd7256165172d1f3fae3331099bb7 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 12:14:22 -0700 Subject: [PATCH 113/158] wip switching to ruff for formatting --- makefile | 16 +++++++++++++--- pyproject.toml | 7 ++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/makefile b/makefile index 8ff3046f..337efdb0 100644 --- a/makefile +++ b/makefile @@ -45,7 +45,7 @@ else PYTHON = $(PYTHON_BASE) endif -PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") +PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')") COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 10) else 0)") @@ -53,8 +53,16 @@ TYPECHECK_ARGS ?= # formatting # -------------------------------------------------- + +.PHONY: setup-format +setup-format: + @echo "install only packages needed for formatting, direct via pip (useful for CI)" + $(PYTHON_BASE) -c 'import re,tomllib; cfg = tomllib.load(open("$(PYPROJECT)", "rb")); deps = [(pkg, re.match(r"^\D*(\d.*)", ver).group(1)) for pkg, ver in cfg["tool"]["poetry"]["group"]["dev"]["dependencies"].items() if pkg in ["ruff", "pycln"]]; print(" ".join([f"{pkg}=={ver}" for pkg,ver in deps]))' | xargs $(PYTHON) -m pip install + .PHONY: format format: + @echo "format the source code" + $(PYTHON) -m ruff format $(PYTHON) -m pycln --config $(PYPROJECT) --all . $(PYTHON) -m isort format . $(PYTHON) -m black . @@ -62,6 +70,7 @@ format: .PHONY: check-format check-format: @echo "run format check" + $(PYTHON) -m ruff check $(PYTHON) -m pycln --check --config $(PYPROJECT) . $(PYTHON) -m isort --check-only . $(PYTHON) -m black --check . @@ -199,14 +208,15 @@ publish: check build verify-git version clean: @echo "cleaning up" rm -rf .mypy_cache + rm -rf .ruff_cache rm -rf .pytest_cache rm -rf .coverage rm -rf dist rm -rf build rm -rf $(PACKAGE_NAME).egg-info rm -rf tests/junk_data - python -Bc "import pathlib; [p.unlink() for p in pathlib.Path('.').rglob('*.py[co]')]" - python -Bc "import pathlib; [p.rmdir() for p in pathlib.Path('.').rglob('__pycache__')]" + $(PYTHON_BASE) -Bc "import pathlib; [p.unlink() for p in pathlib.Path('.').rglob('*.py[co]')]" + $(PYTHON_BASE) -Bc "import pathlib; [p.rmdir() for p in pathlib.Path('.').rglob('__pycache__')]" rm -rf tests/unit/validate_type/test_validate_type_MODERN.py # listing targets, from stackoverflow diff --git a/pyproject.toml b/pyproject.toml index 8c1ee29f..2b230115 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,10 +33,11 @@ zanj = ["zanj"] [tool.poetry.group.dev.dependencies] pytest = "^8.2.2" -black = "^24.1.1" -pylint = "^2.16.4" +ruff = "^0.4.8" +# black = "^24.1.1" +# pylint = "^2.16.4" +# isort = "^5.12.0" pycln = "^2.1.3" -isort = "^5.12.0" mypy = "^1.0.1" pytest-cov = "^4.1.0" coverage-badge = "^1.1.0" From 32bbc79e466735e91adac1e6c7a28bdc09d8b3ef Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:09:07 -0700 Subject: [PATCH 114/158] update makefile, switch to ruff --- dev-requirements.txt | 18 +-- makefile | 159 ++++++++++++++-------- poetry.lock | 311 +++++-------------------------------------- pyproject.toml | 25 ++-- 4 files changed, 151 insertions(+), 362 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 6de99214..18b678be 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,15 +1,12 @@ -astroid==2.15.8 ; python_version >= "3.8" and python_version < "4.0" asttokens==2.4.1 ; python_version >= "3.10" and python_version < "4.0" -black==24.4.2 ; python_version >= "3.8" and python_version < "4.0" -click==8.1.7 ; python_version >= "3.8" and python_version < "4.0" -colorama==0.4.6 ; python_version >= "3.8" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") +click==8.1.7 ; python_version >= "3.8" and python_version < "4" +colorama==0.4.6 ; python_version >= "3.8" and sys_platform == "win32" and python_version < "4.0" or python_version >= "3.8" and python_version < "4" and platform_system == "Windows" contourpy==1.1.1 ; python_version >= "3.8" and python_version < "4.0" coverage-badge==1.1.1 ; python_version >= "3.8" and python_version < "4.0" coverage==7.5.3 ; python_version >= "3.8" and python_version < "4.0" coverage[toml]==7.5.3 ; python_version >= "3.8" and python_version < "4.0" cycler==0.12.1 ; python_version >= "3.8" and python_version < "4.0" decorator==5.1.1 ; python_version >= "3.10" and python_version < "4.0" -dill==0.3.8 ; python_version >= "3.8" and python_version < "4.0" exceptiongroup==1.2.1 ; python_version >= "3.8" and python_version < "3.11" executing==2.0.1 ; python_version >= "3.10" and python_version < "4.0" fonttools==4.53.0 ; python_version >= "3.8" and python_version < "4.0" @@ -17,26 +14,22 @@ importlib-metadata==7.1.0 ; python_version >= "3.8" and python_version < "3.10" importlib-resources==6.4.0 ; python_version >= "3.8" and python_version < "3.10" iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "4.0" ipython==8.25.0 ; python_version >= "3.10" and python_version < "4.0" -isort==5.13.2 ; python_version >= "3.8" and python_version < "4.0" jaxtyping==0.2.19 ; python_version >= "3.8" and python_version < "4.0" jedi==0.19.1 ; python_version >= "3.10" and python_version < "4.0" kiwisolver==1.4.5 ; python_version >= "3.8" and python_version < "4.0" -lazy-object-proxy==1.10.0 ; python_version >= "3.8" and python_version < "4.0" libcst==1.1.0 ; python_version >= "3.8" and python_version < "4" markdown-it-py==3.0.0 ; python_version >= "3.8" and python_version < "4" matplotlib-inline==0.1.7 ; python_version >= "3.10" and python_version < "4.0" matplotlib==3.7.5 ; python_version >= "3.8" and python_version < "4.0" -mccabe==0.7.0 ; python_version >= "3.8" and python_version < "4.0" mdurl==0.1.2 ; python_version >= "3.8" and python_version < "4" mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "4.0" mypy==1.10.0 ; python_version >= "3.8" and python_version < "4.0" numpy==1.24.4 ; python_version >= "3.8" and python_version < "4.0" packaging==24.1 ; python_version >= "3.8" and python_version < "4.0" parso==0.8.4 ; python_version >= "3.10" and python_version < "4.0" -pathspec==0.12.1 ; python_version >= "3.8" and python_version < "4.0" +pathspec==0.12.1 ; python_version >= "3.8" and python_version < "4" pexpect==4.9.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten") pillow==10.3.0 ; python_version >= "3.8" and python_version < "4.0" -platformdirs==4.2.2 ; python_version >= "3.8" and python_version < "4.0" plotly==5.22.0 ; python_version >= "3.8" and python_version < "4.0" pluggy==1.5.0 ; python_version >= "3.8" and python_version < "4.0" prompt-toolkit==3.0.47 ; python_version >= "3.10" and python_version < "4.0" @@ -44,24 +37,23 @@ ptyprocess==0.7.0 ; python_version >= "3.10" and python_version < "4.0" and (sys pure-eval==0.2.2 ; python_version >= "3.10" and python_version < "4.0" pycln==2.4.0 ; python_version >= "3.8" and python_version < "4" pygments==2.18.0 ; python_version >= "3.8" and python_version < "4.0" -pylint==2.17.7 ; python_version >= "3.8" and python_version < "4.0" pyparsing==3.1.2 ; python_version >= "3.8" and python_version < "4.0" pytest-cov==4.1.0 ; python_version >= "3.8" and python_version < "4.0" pytest==8.2.2 ; python_version >= "3.8" and python_version < "4.0" python-dateutil==2.9.0.post0 ; python_version >= "3.8" and python_version < "4.0" pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4" rich==13.7.1 ; python_version >= "3.8" and python_version < "4" +ruff==0.4.9 ; python_version >= "3.8" and python_version < "4.0" shellingham==1.5.4 ; python_version >= "3.8" and python_version < "4" six==1.16.0 ; python_version >= "3.8" and python_version < "4.0" stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0" tenacity==8.4.1 ; python_version >= "3.8" and python_version < "4.0" tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" -tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4.0" +tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4" traitlets==5.14.3 ; python_version >= "3.10" and python_version < "4.0" typeguard==4.3.0 ; python_version >= "3.8" and python_version < "4.0" typer==0.12.3 ; python_version >= "3.8" and python_version < "4" typing-extensions==4.12.2 ; python_version >= "3.8" and python_version < "4.0" typing-inspect==0.9.0 ; python_version >= "3.8" and python_version < "4" wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "4.0" -wrapt==1.16.0 ; python_version >= "3.8" and python_version < "4.0" zipp==3.19.2 ; python_version >= "3.8" and python_version < "3.10" diff --git a/makefile b/makefile index 337efdb0..51fa3c3b 100644 --- a/makefile +++ b/makefile @@ -1,43 +1,63 @@ +# configuration +# ================================================== +# MODIFY THIS FILE TO SUIT YOUR PROJECT +# it assumes that the source is in a directory named the same as the package name PACKAGE_NAME := muutils +# for checking you are on the right branch when publishing PUBLISH_BRANCH := main -PYPI_TOKEN_FILE := .pypi-token -LAST_VERSION_FILE := .lastversion +# where to put the coverage reports COVERAGE_REPORTS_DIR := docs/coverage +# where the tests are (assumes pytest) TESTS_DIR := tests/unit +# temp directory to clean up +TESTS_TEMP_DIR := tests/_temp + +# probably don't change these: +# -------------------------------------------------- +# will print this token when publishing +PYPI_TOKEN_FILE := .pypi-token +# the last version that was auto-uploaded. will use this to create a commit log for version tag +LAST_VERSION_FILE := .lastversion +# where the pyproject.toml file is PYPROJECT := pyproject.toml +# base python to use. Will add `poetry run` in front of this if `RUN_GLOBAL` is not set to 1 +PYTHON_BASE := python +# where the commit log will be stored +COMMIT_LOG_FILE := .commit_log + + +# reading information and command line options +# ================================================== + +# reading version +# -------------------------------------------------- +# assuming your pyproject.toml has a line that looks like `version = "0.0.1"`, will get the version VERSION := $(shell python -c "import re; print(re.search(r'^version\s*=\s*\"(.+?)\"', open('$(PYPROJECT)').read(), re.MULTILINE).group(1))") -LAST_VERSION := $(shell cat $(LAST_VERSION_FILE)) -PYTHON_BASE := python +# read last auto-uploaded version from file +LAST_VERSION := $(shell [ -f $(LAST_VERSION_FILE) ] && cat $(LAST_VERSION_FILE) || echo NONE) + +# getting commit log +# -------------------------------------------------- # note that the commands at the end: # 1) format the git log # 2) replace backticks with single quotes, to avoid funny business # 3) add a final newline, to make tac happy # 4) reverse the order of the lines, so that the oldest commit is first # 5) replace newlines with tabs, to prevent the newlines from being lost -COMMIT_LOG_FILE := .commit_log -COMMIT_LOG_SINCE_LAST_VERSION := $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t') +ifeq ($(LAST_VERSION),NONE) + COMMIT_LOG_SINCE_LAST_VERSION := "No last version found, cannot generate commit log" +else + COMMIT_LOG_SINCE_LAST_VERSION := $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t') # 1 2 3 4 5 +endif -.PHONY: default -default: help -.PHONY: version -version: - @echo "Current version is $(VERSION), last auto-uploaded version is $(LAST_VERSION)" - @echo "Commit log since last version:" - @echo "$(COMMIT_LOG_SINCE_LAST_VERSION)" | tr '\t' '\n' > $(COMMIT_LOG_FILE) - @cat $(COMMIT_LOG_FILE) - @if [ "$(VERSION)" = "$(LAST_VERSION)" ]; then \ - echo "Python package $(VERSION) is the same as last published version $(LAST_VERSION), exiting!"; \ - exit 1; \ - fi - -# command line options +# RUN_GLOBAL=1 to use global `PYTHON_BASE` instead of `poetry run $(PYTHON_BASE)` # -------------------------------------------------- -# for formatting or CI, we might want to run python without setting up all of poetry +# for formatting, we might want to run python without setting up all of poetry RUN_GLOBAL ?= 0 ifeq ($(RUN_GLOBAL),0) PYTHON = poetry run $(PYTHON_BASE) @@ -45,48 +65,31 @@ else PYTHON = $(PYTHON_BASE) endif +# get the python version now that we have picked the python command +# -------------------------------------------------- PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')") +# looser typing, allow warnings for python <3.10 +# -------------------------------------------------- COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 10) else 0)") - TYPECHECK_ARGS ?= -# formatting +# options we might want to pass to pytest # -------------------------------------------------- +PYTEST_OPTIONS ?= # using ?= means you can pass extra options from the command line +COV ?= 1 -.PHONY: setup-format -setup-format: - @echo "install only packages needed for formatting, direct via pip (useful for CI)" - $(PYTHON_BASE) -c 'import re,tomllib; cfg = tomllib.load(open("$(PYPROJECT)", "rb")); deps = [(pkg, re.match(r"^\D*(\d.*)", ver).group(1)) for pkg, ver in cfg["tool"]["poetry"]["group"]["dev"]["dependencies"].items() if pkg in ["ruff", "pycln"]]; print(" ".join([f"{pkg}=={ver}" for pkg,ver in deps]))' | xargs $(PYTHON) -m pip install - -.PHONY: format -format: - @echo "format the source code" - $(PYTHON) -m ruff format - $(PYTHON) -m pycln --config $(PYPROJECT) --all . - $(PYTHON) -m isort format . - $(PYTHON) -m black . +ifdef VERBOSE + PYTEST_OPTIONS += --verbose +endif -.PHONY: check-format -check-format: - @echo "run format check" - $(PYTHON) -m ruff check - $(PYTHON) -m pycln --check --config $(PYPROJECT) . - $(PYTHON) -m isort --check-only . - $(PYTHON) -m black --check . +ifeq ($(COV),1) + PYTEST_OPTIONS += --cov=. +endif -# pytest options and coverage +# compatibility mode for python <3.10 # -------------------------------------------------- -PYTEST_OPTIONS ?= - -# whether to run pytest with coverage report generation -COV ?= 1 - -ifneq ($(COV), 0) - PYTEST_OPTIONS += --cov=. -endif - # whether to run pytest with warnings as errors WARN_STRICT ?= 0 @@ -94,8 +97,6 @@ ifneq ($(WARN_STRICT), 0) PYTEST_OPTIONS += -W error endif -# compatibility mode for python <3.10 - # Update the PYTEST_OPTIONS to include the conditional ignore option ifeq ($(COMPATIBILITY_MODE), 1) JUNK := $(info WARNING: Detected python version less than 3.10, some behavior will be different) @@ -103,6 +104,48 @@ ifeq ($(COMPATIBILITY_MODE), 1) TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found endif + +# default target (help) +# ================================================== + +.PHONY: default +default: help + +.PHONY: version +version: + @echo "Current version is $(VERSION), last auto-uploaded version is $(LAST_VERSION)" + @echo "Commit log since last version:" + @echo "$(COMMIT_LOG_SINCE_LAST_VERSION)" | tr '\t' '\n' > $(COMMIT_LOG_FILE) + @cat $(COMMIT_LOG_FILE) + @if [ "$(VERSION)" = "$(LAST_VERSION)" ]; then \ + echo "Python package $(VERSION) is the same as last published version $(LAST_VERSION), exiting!"; \ + exit 1; \ + fi + + +# formatting +# ================================================== + +.PHONY: setup-format +setup-format: + @echo "install only packages needed for formatting, direct via pip (useful for CI)" + $(PYTHON_BASE) -c 'import re,tomllib; cfg = tomllib.load(open("$(PYPROJECT)", "rb")); deps = [(pkg, re.match(r"^\D*(\d.*)", ver).group(1)) for pkg, ver in cfg["tool"]["poetry"]["group"]["dev"]["dependencies"].items() if pkg in ["ruff", "pycln"]]; print(" ".join([f"{pkg}=={ver}" for pkg,ver in deps]))' | xargs $(PYTHON) -m pip install + +.PHONY: format +format: + @echo "format the source code" + $(PYTHON) -m ruff format --config $(PYPROJECT) . + $(PYTHON) -m pycln --config $(PYPROJECT) --all . + +.PHONY: check-format +check-format: + @echo "run format check" + $(PYTHON) -m ruff check --config $(PYPROJECT) . + $(PYTHON) -m pycln --check --config $(PYPROJECT) . + +# coverage +# ================================================== + .PHONY: cov cov: @echo "generate coverage reports" @@ -111,7 +154,7 @@ cov: $(PYTHON) -m coverage html # tests -# -------------------------------------------------- +# ================================================== # at some point, need to add back --check-untyped-defs to mypy call # but it complains when we specify arguments by keyword where positional is fine @@ -141,7 +184,7 @@ check: clean check-format clean test lint @echo "run format check, test, and lint" # build and publish -# -------------------------------------------------- +# ================================================== .PHONY: verify-git verify-git: @@ -202,7 +245,7 @@ publish: check build verify-git version twine upload dist/* --verbose # cleanup -# -------------------------------------------------- +# ================================================== .PHONY: clean clean: @@ -223,7 +266,7 @@ clean: # https://stackoverflow.com/questions/4219255/how-do-you-get-the-list-of-targets-in-a-makefile .PHONY: help help: - @echo -n "# list make targets" + @echo -n "list make targets" @echo ":" @cat Makefile | sed -n '/^\.PHONY: / h; /\(^\t@*echo\|^\t:\)/ {H; x; /PHONY/ s/.PHONY: \(.*\)\n.*"\(.*\)"/ make \1\t\2/p; d; x}'| sort -k2,2 |expand -t 25 @echo "# makefile variables:" diff --git a/poetry.lock b/poetry.lock index c5d8915a..b858b705 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,24 +1,5 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. -[[package]] -name = "astroid" -version = "2.15.8" -description = "An abstract syntax tree for Python with inference support." -optional = false -python-versions = ">=3.7.2" -files = [ - {file = "astroid-2.15.8-py3-none-any.whl", hash = "sha256:1aa149fc5c6589e3d0ece885b4491acd80af4f087baafa3fb5203b113e68cd3c"}, - {file = "astroid-2.15.8.tar.gz", hash = "sha256:6c107453dffee9055899705de3c9ead36e74119cee151e5a9aaf7f0b0e020a6a"}, -] - -[package.dependencies] -lazy-object-proxy = ">=1.4.0" -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} -wrapt = [ - {version = ">=1.11,<2", markers = "python_version < \"3.11\""}, - {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, -] - [[package]] name = "asttokens" version = "2.4.1" @@ -37,52 +18,6 @@ six = ">=1.12.0" astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] -[[package]] -name = "black" -version = "24.4.2" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.8" -files = [ - {file = "black-24.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dd1b5a14e417189db4c7b64a6540f31730713d173f0b63e55fabd52d61d8fdce"}, - {file = "black-24.4.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e537d281831ad0e71007dcdcbe50a71470b978c453fa41ce77186bbe0ed6021"}, - {file = "black-24.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaea3008c281f1038edb473c1aa8ed8143a5535ff18f978a318f10302b254063"}, - {file = "black-24.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:7768a0dbf16a39aa5e9a3ded568bb545c8c2727396d063bbaf847df05b08cd96"}, - {file = "black-24.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:257d724c2c9b1660f353b36c802ccece186a30accc7742c176d29c146df6e474"}, - {file = "black-24.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bdde6f877a18f24844e381d45e9947a49e97933573ac9d4345399be37621e26c"}, - {file = "black-24.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e151054aa00bad1f4e1f04919542885f89f5f7d086b8a59e5000e6c616896ffb"}, - {file = "black-24.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:7e122b1c4fb252fd85df3ca93578732b4749d9be076593076ef4d07a0233c3e1"}, - {file = "black-24.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:accf49e151c8ed2c0cdc528691838afd217c50412534e876a19270fea1e28e2d"}, - {file = "black-24.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88c57dc656038f1ab9f92b3eb5335ee9b021412feaa46330d5eba4e51fe49b04"}, - {file = "black-24.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be8bef99eb46d5021bf053114442914baeb3649a89dc5f3a555c88737e5e98fc"}, - {file = "black-24.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:415e686e87dbbe6f4cd5ef0fbf764af7b89f9057b97c908742b6008cc554b9c0"}, - {file = "black-24.4.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bf10f7310db693bb62692609b397e8d67257c55f949abde4c67f9cc574492cc7"}, - {file = "black-24.4.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:98e123f1d5cfd42f886624d84464f7756f60ff6eab89ae845210631714f6db94"}, - {file = "black-24.4.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48a85f2cb5e6799a9ef05347b476cce6c182d6c71ee36925a6c194d074336ef8"}, - {file = "black-24.4.2-cp38-cp38-win_amd64.whl", hash = "sha256:b1530ae42e9d6d5b670a34db49a94115a64596bc77710b1d05e9801e62ca0a7c"}, - {file = "black-24.4.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:37aae07b029fa0174d39daf02748b379399b909652a806e5708199bd93899da1"}, - {file = "black-24.4.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:da33a1a5e49c4122ccdfd56cd021ff1ebc4a1ec4e2d01594fef9b6f267a9e741"}, - {file = "black-24.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef703f83fc32e131e9bcc0a5094cfe85599e7109f896fe8bc96cc402f3eb4b6e"}, - {file = "black-24.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:b9176b9832e84308818a99a561e90aa479e73c523b3f77afd07913380ae2eab7"}, - {file = "black-24.4.2-py3-none-any.whl", hash = "sha256:d36ed1124bb81b32f8614555b34cc4259c3fbc7eec17870e8ff8ded335b58d8c"}, - {file = "black-24.4.2.tar.gz", hash = "sha256:c872b53057f000085da66a19c55d68f6f8ddcac2642392ad3a355878406fbd4d"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "click" version = "8.1.7" @@ -348,21 +283,6 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] -[[package]] -name = "dill" -version = "0.3.8" -description = "serialize all of Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, - {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, -] - -[package.extras] -graph = ["objgraph (>=1.7.2)"] -profile = ["gprof2dot (>=2022.7.29)"] - [[package]] name = "exceptiongroup" version = "1.2.1" @@ -393,18 +313,18 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth [[package]] name = "filelock" -version = "3.15.1" +version = "3.15.3" description = "A platform independent file lock." optional = true python-versions = ">=3.8" files = [ - {file = "filelock-3.15.1-py3-none-any.whl", hash = "sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac"}, - {file = "filelock-3.15.1.tar.gz", hash = "sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8"}, + {file = "filelock-3.15.3-py3-none-any.whl", hash = "sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103"}, + {file = "filelock-3.15.3.tar.gz", hash = "sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -611,20 +531,6 @@ qtconsole = ["qtconsole"] test = ["pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"] test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"] -[[package]] -name = "isort" -version = "5.13.2" -description = "A Python utility / library to sort Python imports." -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, - {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, -] - -[package.extras] -colors = ["colorama (>=0.4.6)"] - [[package]] name = "jaxtyping" version = "0.2.19" @@ -790,52 +696,6 @@ files = [ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, ] -[[package]] -name = "lazy-object-proxy" -version = "1.10.0" -description = "A fast and thorough lazy object proxy." -optional = false -python-versions = ">=3.8" -files = [ - {file = "lazy-object-proxy-1.10.0.tar.gz", hash = "sha256:78247b6d45f43a52ef35c25b5581459e85117225408a4128a3daf8bf9648ac69"}, - {file = "lazy_object_proxy-1.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:855e068b0358ab916454464a884779c7ffa312b8925c6f7401e952dcf3b89977"}, - {file = "lazy_object_proxy-1.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab7004cf2e59f7c2e4345604a3e6ea0d92ac44e1c2375527d56492014e690c3"}, - {file = "lazy_object_proxy-1.10.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc0d2fc424e54c70c4bc06787e4072c4f3b1aa2f897dfdc34ce1013cf3ceef05"}, - {file = "lazy_object_proxy-1.10.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e2adb09778797da09d2b5ebdbceebf7dd32e2c96f79da9052b2e87b6ea495895"}, - {file = "lazy_object_proxy-1.10.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b1f711e2c6dcd4edd372cf5dec5c5a30d23bba06ee012093267b3376c079ec83"}, - {file = "lazy_object_proxy-1.10.0-cp310-cp310-win32.whl", hash = "sha256:76a095cfe6045c7d0ca77db9934e8f7b71b14645f0094ffcd842349ada5c5fb9"}, - {file = "lazy_object_proxy-1.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:b4f87d4ed9064b2628da63830986c3d2dca7501e6018347798313fcf028e2fd4"}, - {file = "lazy_object_proxy-1.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fec03caabbc6b59ea4a638bee5fce7117be8e99a4103d9d5ad77f15d6f81020c"}, - {file = "lazy_object_proxy-1.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02c83f957782cbbe8136bee26416686a6ae998c7b6191711a04da776dc9e47d4"}, - {file = "lazy_object_proxy-1.10.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:009e6bb1f1935a62889ddc8541514b6a9e1fcf302667dcb049a0be5c8f613e56"}, - {file = "lazy_object_proxy-1.10.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:75fc59fc450050b1b3c203c35020bc41bd2695ed692a392924c6ce180c6f1dc9"}, - {file = "lazy_object_proxy-1.10.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:782e2c9b2aab1708ffb07d4bf377d12901d7a1d99e5e410d648d892f8967ab1f"}, - {file = "lazy_object_proxy-1.10.0-cp311-cp311-win32.whl", hash = "sha256:edb45bb8278574710e68a6b021599a10ce730d156e5b254941754a9cc0b17d03"}, - {file = "lazy_object_proxy-1.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:e271058822765ad5e3bca7f05f2ace0de58a3f4e62045a8c90a0dfd2f8ad8cc6"}, - {file = "lazy_object_proxy-1.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e98c8af98d5707dcdecc9ab0863c0ea6e88545d42ca7c3feffb6b4d1e370c7ba"}, - {file = "lazy_object_proxy-1.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:952c81d415b9b80ea261d2372d2a4a2332a3890c2b83e0535f263ddfe43f0d43"}, - {file = "lazy_object_proxy-1.10.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80b39d3a151309efc8cc48675918891b865bdf742a8616a337cb0090791a0de9"}, - {file = "lazy_object_proxy-1.10.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e221060b701e2aa2ea991542900dd13907a5c90fa80e199dbf5a03359019e7a3"}, - {file = "lazy_object_proxy-1.10.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:92f09ff65ecff3108e56526f9e2481b8116c0b9e1425325e13245abfd79bdb1b"}, - {file = "lazy_object_proxy-1.10.0-cp312-cp312-win32.whl", hash = "sha256:3ad54b9ddbe20ae9f7c1b29e52f123120772b06dbb18ec6be9101369d63a4074"}, - {file = "lazy_object_proxy-1.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:127a789c75151db6af398b8972178afe6bda7d6f68730c057fbbc2e96b08d282"}, - {file = "lazy_object_proxy-1.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4ed0518a14dd26092614412936920ad081a424bdcb54cc13349a8e2c6d106a"}, - {file = "lazy_object_proxy-1.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ad9e6ed739285919aa9661a5bbed0aaf410aa60231373c5579c6b4801bd883c"}, - {file = "lazy_object_proxy-1.10.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fc0a92c02fa1ca1e84fc60fa258458e5bf89d90a1ddaeb8ed9cc3147f417255"}, - {file = "lazy_object_proxy-1.10.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0aefc7591920bbd360d57ea03c995cebc204b424524a5bd78406f6e1b8b2a5d8"}, - {file = "lazy_object_proxy-1.10.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5faf03a7d8942bb4476e3b62fd0f4cf94eaf4618e304a19865abf89a35c0bbee"}, - {file = "lazy_object_proxy-1.10.0-cp38-cp38-win32.whl", hash = "sha256:e333e2324307a7b5d86adfa835bb500ee70bfcd1447384a822e96495796b0ca4"}, - {file = "lazy_object_proxy-1.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:cb73507defd385b7705c599a94474b1d5222a508e502553ef94114a143ec6696"}, - {file = "lazy_object_proxy-1.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:366c32fe5355ef5fc8a232c5436f4cc66e9d3e8967c01fb2e6302fd6627e3d94"}, - {file = "lazy_object_proxy-1.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2297f08f08a2bb0d32a4265e98a006643cd7233fb7983032bd61ac7a02956b3b"}, - {file = "lazy_object_proxy-1.10.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18dd842b49456aaa9a7cf535b04ca4571a302ff72ed8740d06b5adcd41fe0757"}, - {file = "lazy_object_proxy-1.10.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:217138197c170a2a74ca0e05bddcd5f1796c735c37d0eee33e43259b192aa424"}, - {file = "lazy_object_proxy-1.10.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9a3a87cf1e133e5b1994144c12ca4aa3d9698517fe1e2ca82977781b16955658"}, - {file = "lazy_object_proxy-1.10.0-cp39-cp39-win32.whl", hash = "sha256:30b339b2a743c5288405aa79a69e706a06e02958eab31859f7f3c04980853b70"}, - {file = "lazy_object_proxy-1.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:a899b10e17743683b293a729d3a11f2f399e8a90c73b089e29f5d0fe3509f0dd"}, - {file = "lazy_object_proxy-1.10.0-pp310.pp311.pp312.pp38.pp39-none-any.whl", hash = "sha256:80fa48bd89c8f2f456fc0765c11c23bf5af827febacd2f523ca5bc1893fcc09d"}, -] - [[package]] name = "libcst" version = "1.1.0" @@ -1059,17 +919,6 @@ files = [ [package.dependencies] traitlets = "*" -[[package]] -name = "mccabe" -version = "0.7.0" -description = "McCabe checker, plugin for flake8" -optional = false -python-versions = ">=3.6" -files = [ - {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, - {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, -] - [[package]] name = "mdurl" version = "0.1.2" @@ -1508,22 +1357,6 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa typing = ["typing-extensions"] xmp = ["defusedxml"] -[[package]] -name = "platformdirs" -version = "4.2.2" -description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." -optional = false -python-versions = ">=3.8" -files = [ - {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, - {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, -] - -[package.extras] -docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] -type = ["mypy (>=1.8)"] - [[package]] name = "plotly" version = "5.22.0" @@ -1625,35 +1458,6 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] -[[package]] -name = "pylint" -version = "2.17.7" -description = "python code static checker" -optional = false -python-versions = ">=3.7.2" -files = [ - {file = "pylint-2.17.7-py3-none-any.whl", hash = "sha256:27a8d4c7ddc8c2f8c18aa0050148f89ffc09838142193fdbe98f172781a3ff87"}, - {file = "pylint-2.17.7.tar.gz", hash = "sha256:f4fcac7ae74cfe36bc8451e931d8438e4a476c20314b1101c458ad0f05191fad"}, -] - -[package.dependencies] -astroid = ">=2.15.8,<=2.17.0-dev0" -colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} -dill = [ - {version = ">=0.2", markers = "python_version < \"3.11\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, -] -isort = ">=4.2.5,<6" -mccabe = ">=0.6,<0.8" -platformdirs = ">=2.2.0" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -tomlkit = ">=0.10.1" -typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} - -[package.extras] -spelling = ["pyenchant (>=3.2,<4.0)"] -testutils = ["gitpython (>3)"] - [[package]] name = "pyparsing" version = "3.1.2" @@ -1800,6 +1604,32 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "ruff" +version = "0.4.9" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.4.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b262ed08d036ebe162123170b35703aaf9daffecb698cd367a8d585157732991"}, + {file = "ruff-0.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:98ec2775fd2d856dc405635e5ee4ff177920f2141b8e2d9eb5bd6efd50e80317"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4555056049d46d8a381f746680db1c46e67ac3b00d714606304077682832998e"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e91175fbe48f8a2174c9aad70438fe9cb0a5732c4159b2a10a3565fea2d94cde"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e8e7b95673f22e0efd3571fb5b0cf71a5eaaa3cc8a776584f3b2cc878e46bff"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:2d45ddc6d82e1190ea737341326ecbc9a61447ba331b0a8962869fcada758505"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:78de3fdb95c4af084087628132336772b1c5044f6e710739d440fc0bccf4d321"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:06b60f91bfa5514bb689b500a25ba48e897d18fea14dce14b48a0c40d1635893"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88bffe9c6a454bf8529f9ab9091c99490578a593cc9f9822b7fc065ee0712a06"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:673bddb893f21ab47a8334c8e0ea7fd6598ecc8e698da75bcd12a7b9d0a3206e"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8c1aff58c31948cc66d0b22951aa19edb5af0a3af40c936340cd32a8b1ab7438"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:784d3ec9bd6493c3b720a0b76f741e6c2d7d44f6b2be87f5eef1ae8cc1d54c84"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:732dd550bfa5d85af8c3c6cbc47ba5b67c6aed8a89e2f011b908fc88f87649db"}, + {file = "ruff-0.4.9-py3-none-win32.whl", hash = "sha256:8064590fd1a50dcf4909c268b0e7c2498253273309ad3d97e4a752bb9df4f521"}, + {file = "ruff-0.4.9-py3-none-win_amd64.whl", hash = "sha256:e0a22c4157e53d006530c902107c7f550b9233e9706313ab57b892d7197d8e52"}, + {file = "ruff-0.4.9-py3-none-win_arm64.whl", hash = "sha256:5d5460f789ccf4efd43f265a58538a2c24dbce15dbf560676e430375f20a8198"}, + {file = "ruff-0.4.9.tar.gz", hash = "sha256:f1cb0828ac9533ba0135d148d214e284711ede33640465e706772645483427e3"}, +] + [[package]] name = "shellingham" version = "1.5.4" @@ -2070,85 +1900,6 @@ files = [ {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, ] -[[package]] -name = "wrapt" -version = "1.16.0" -description = "Module for decorators, wrappers and monkey patching." -optional = false -python-versions = ">=3.6" -files = [ - {file = "wrapt-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4"}, - {file = "wrapt-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020"}, - {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440"}, - {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487"}, - {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf"}, - {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72"}, - {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0"}, - {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136"}, - {file = "wrapt-1.16.0-cp310-cp310-win32.whl", hash = "sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d"}, - {file = "wrapt-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2"}, - {file = "wrapt-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09"}, - {file = "wrapt-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d"}, - {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389"}, - {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060"}, - {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1"}, - {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3"}, - {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956"}, - {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d"}, - {file = "wrapt-1.16.0-cp311-cp311-win32.whl", hash = "sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362"}, - {file = "wrapt-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89"}, - {file = "wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b"}, - {file = "wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36"}, - {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73"}, - {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809"}, - {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b"}, - {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81"}, - {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9"}, - {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c"}, - {file = "wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc"}, - {file = "wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8"}, - {file = "wrapt-1.16.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8"}, - {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39"}, - {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c"}, - {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40"}, - {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc"}, - {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e"}, - {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465"}, - {file = "wrapt-1.16.0-cp36-cp36m-win32.whl", hash = "sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e"}, - {file = "wrapt-1.16.0-cp36-cp36m-win_amd64.whl", hash = "sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966"}, - {file = "wrapt-1.16.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593"}, - {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292"}, - {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5"}, - {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf"}, - {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228"}, - {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f"}, - {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c"}, - {file = "wrapt-1.16.0-cp37-cp37m-win32.whl", hash = "sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c"}, - {file = "wrapt-1.16.0-cp37-cp37m-win_amd64.whl", hash = "sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00"}, - {file = "wrapt-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0"}, - {file = "wrapt-1.16.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202"}, - {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0"}, - {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e"}, - {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f"}, - {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267"}, - {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca"}, - {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6"}, - {file = "wrapt-1.16.0-cp38-cp38-win32.whl", hash = "sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b"}, - {file = "wrapt-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41"}, - {file = "wrapt-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2"}, - {file = "wrapt-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb"}, - {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8"}, - {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c"}, - {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a"}, - {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664"}, - {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f"}, - {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537"}, - {file = "wrapt-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3"}, - {file = "wrapt-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35"}, - {file = "wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1"}, - {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, -] - [[package]] name = "zanj" version = "0.2.2" @@ -2190,4 +1941,4 @@ zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "6992867457041d99bcd9915054e5e1d3c666f8a6eee2aba70835323fa85330ce" +content-hash = "e2a09de8d7c53e28a74a03217b0d45ccd5f48ba2c3993f0ea07baef8a876d91b" diff --git a/pyproject.toml b/pyproject.toml index 2b230115..1aa89971 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,18 +63,21 @@ filterwarnings = [ "ignore::muutils.json_serialize.serializable_dataclass.ZanjMissingWarning", # don't show warning for missing zanj (can't have as a dep since zanj depends on muutils) ] -[tool.pycln] -all = true -exclude = ["tests/input_data", "tests/junk_data"] +[tool.ruff] +# Exclude the directories specified in the global excludes +exclude = ["tests/input_data", "tests/junk_data", "muutils/_wip"] +# # allow trailing commas +# no-trailing-comma = false -[tool.isort] -profile = "black" -ignore_comments = false -extend_skip = ["tests/input_data", "tests/junk_data"] +# # Per-file rule ignores +# [tool.ruff.per-file-ignores] +# # Ignore the "F401" (unused imports) rule for __init__.py files +# "__init__.py" = ["F401"] -[tool.black] -extend-exclude = "tests/input_data" +[tool.pycln] +all = true +exclude = ["tests/input_data", "tests/junk_data", "muutils/_wip"] [tool.mypy] -exclude = ['_wip', "tests/input_data", "tests/junk_data"] -show_error_codes = true +exclude = ["tests/input_data", "tests/junk_data", "muutils/_wip"] +show_error_codes = true \ No newline at end of file From ae1cda870c9fda4789155156c77d922ded3c943b Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:28:47 -0700 Subject: [PATCH 115/158] run safe format --- muutils/kappa.py | 2 +- muutils/logger/logger.py | 4 ++-- muutils/misc.py | 1 - muutils/mlutils.py | 4 +--- muutils/nbutils/configure_notebook.py | 10 +++++++++- muutils/nbutils/run_notebook_tests.py | 4 +--- muutils/sysinfo.py | 1 - muutils/validate_type.py | 1 - .../test_serializable_dataclass.py | 6 ++++-- tests/unit/test_mlutils.py | 1 - tests/unit/test_statcounter.py | 1 - tests/unit/test_tensor_utils.py | 1 - 12 files changed, 18 insertions(+), 18 deletions(-) diff --git a/muutils/kappa.py b/muutils/kappa.py index 888d9c2b..819f1ad6 100644 --- a/muutils/kappa.py +++ b/muutils/kappa.py @@ -1,5 +1,5 @@ """anonymous getitem class - + util for constructing a class which has a getitem method which just calls a function a `lambda` is an anonymous function: kappa is the letter before lambda in the greek alphabet, diff --git a/muutils/logger/logger.py b/muutils/logger/logger.py index 20d46755..fca5a007 100644 --- a/muutils/logger/logger.py +++ b/muutils/logger/logger.py @@ -2,8 +2,8 @@ - `SimpleLogger` is an extremely simple logger that can write to both console and a file - `Logger` class handles levels in a slightly different way than default python `logging`, - and also has "streams" which allow for different sorts of output in the same logger - this was mostly made with training models in mind and storing both metadata and loss + and also has "streams" which allow for different sorts of output in the same logger + this was mostly made with training models in mind and storing both metadata and loss - `TimerContext` is a context manager that can be used to time the duration of a block of code """ diff --git a/muutils/misc.py b/muutils/misc.py index dd75ad4c..fdc9453f 100644 --- a/muutils/misc.py +++ b/muutils/misc.py @@ -309,7 +309,6 @@ def str_to_numeric( # decimals else: - try: result = int(quantity) except ValueError: diff --git a/muutils/mlutils.py b/muutils/mlutils.py index ced9494e..b08592a9 100644 --- a/muutils/mlutils.py +++ b/muutils/mlutils.py @@ -115,9 +115,7 @@ def get_checkpoint_paths_for_run( - a wildcard for the iteration number """ - assert ( - run_path.is_dir() - ), f"Model path {run_path} is not a directory (expect run directory, not model files)" + assert run_path.is_dir(), f"Model path {run_path} is not a directory (expect run directory, not model files)" return [ (int(checkpoint_path.stem.split("_")[-1].split(".")[0]), checkpoint_path) diff --git a/muutils/nbutils/configure_notebook.py b/muutils/nbutils/configure_notebook.py index e70116aa..db314e7d 100644 --- a/muutils/nbutils/configure_notebook.py +++ b/muutils/nbutils/configure_notebook.py @@ -90,7 +90,15 @@ def setup_plots( close_after_plotshow: bool = False, ) -> None: """Set up plot saving/rendering options""" - global PLOT_MODE, CONVERSION_PLOTMODE_OVERRIDE, FIG_COUNTER, FIG_OUTPUT_FMT, FIG_NUMBERED_FNAME, FIG_CONFIG, FIG_BASEPATH, CLOSE_AFTER_PLOTSHOW + global \ + PLOT_MODE, \ + CONVERSION_PLOTMODE_OVERRIDE, \ + FIG_COUNTER, \ + FIG_OUTPUT_FMT, \ + FIG_NUMBERED_FNAME, \ + FIG_CONFIG, \ + FIG_BASEPATH, \ + CLOSE_AFTER_PLOTSHOW # set plot mode, handling override if CONVERSION_PLOTMODE_OVERRIDE is not None: diff --git a/muutils/nbutils/run_notebook_tests.py b/muutils/nbutils/run_notebook_tests.py index 74a44fd1..7dd8b37b 100644 --- a/muutils/nbutils/run_notebook_tests.py +++ b/muutils/nbutils/run_notebook_tests.py @@ -67,9 +67,7 @@ def run_notebook_tests( output_file: Path = file.with_suffix(CI_output_suffix) print(f" Output in {output_file}") - command: str = ( - f"{run_python_cmd} {root_relative_to_notebooks / file} > {root_relative_to_notebooks / output_file} 2>&1" - ) + command: str = f"{run_python_cmd} {root_relative_to_notebooks / file} > {root_relative_to_notebooks / output_file} 2>&1" process: subprocess.CompletedProcess = subprocess.run( command, shell=True, text=True ) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index 831783b5..11a3979c 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -154,7 +154,6 @@ def git_info(with_log: bool = False) -> dict: "git status": git_status, } else: - output: dict = { "git version": git_version["stdout"], "git status": git_status, diff --git a/muutils/validate_type.py b/muutils/validate_type.py index ae218e1c..e8b3de23 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -51,7 +51,6 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # generic alias, more complicated item_type: type if isinstance(expected_type, GenericAliasTypes): - if origin is list: # no args if len(args) == 0: diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index e646825f..1e9372c7 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -221,7 +221,8 @@ def full_name(self) -> str: print(serialized_data) loaded_instance = _loading_test_wrapper( - MyClass, serialized_data # , assert_record_len=3 + MyClass, + serialized_data, # , assert_record_len=3 ) print(loaded_instance) @@ -344,7 +345,8 @@ def test_nested_with_container(): assert serialized == expected_ser loaded = _loading_test_wrapper( - Nested_with_Container, serialized # , assert_record_len=12 + Nested_with_Container, + serialized, # , assert_record_len=12 ) assert loaded == instance diff --git a/tests/unit/test_mlutils.py b/tests/unit/test_mlutils.py index fd91c761..f1fb1e63 100644 --- a/tests/unit/test_mlutils.py +++ b/tests/unit/test_mlutils.py @@ -26,7 +26,6 @@ def test_get_checkpoint_paths_for_run(): def test_register_method(recwarn): - class TestEvalsA: evals = {} diff --git a/tests/unit/test_statcounter.py b/tests/unit/test_statcounter.py index 27f6ebae..cab9fbc2 100644 --- a/tests/unit/test_statcounter.py +++ b/tests/unit/test_statcounter.py @@ -49,7 +49,6 @@ def test_statcounter() -> None: # arrs.append(np.random.randint(i, j, size=1000)) for a in arrs: - r = _compare_np_custom(a) assert all( diff --git a/tests/unit/test_tensor_utils.py b/tests/unit/test_tensor_utils.py index 77aa8000..97123fb0 100644 --- a/tests/unit/test_tensor_utils.py +++ b/tests/unit/test_tensor_utils.py @@ -81,7 +81,6 @@ def test_compare_state_dicts(): def test_get_dict_shapes(): - x = {"a": torch.rand(2, 3), "b": torch.rand(1, 3, 5), "c": torch.rand(2)} x_shapes = get_dict_shapes(x) assert x_shapes == {"a": (2, 3), "b": (1, 3, 5), "c": (2,)} From 6e9e42572d4bcc76646133583296cd1660d3fcd8 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:31:35 -0700 Subject: [PATCH 116/158] ruff fix, not just format --- makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/makefile b/makefile index 51fa3c3b..44c29682 100644 --- a/makefile +++ b/makefile @@ -135,6 +135,7 @@ setup-format: format: @echo "format the source code" $(PYTHON) -m ruff format --config $(PYPROJECT) . + $(PYTHON) -m ruff check --fix --config $(PYPROJECT) . $(PYTHON) -m pycln --config $(PYPROJECT) --all . .PHONY: check-format From 2428dd01c7b963320a270804335e6ddc648f101a Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:31:43 -0700 Subject: [PATCH 117/158] run new make format --- muutils/logger/logger.py | 2 +- muutils/misc.py | 4 ++-- muutils/sysinfo.py | 3 +-- .../serializable_dataclass/test_sdc_defaults.py | 2 -- .../serializable_dataclass/test_serializable_dataclass.py | 1 - tests/unit/validate_type/test_validate_type_special.py | 1 - 6 files changed, 4 insertions(+), 9 deletions(-) diff --git a/muutils/logger/logger.py b/muutils/logger/logger.py index fca5a007..b4165f95 100644 --- a/muutils/logger/logger.py +++ b/muutils/logger/logger.py @@ -195,7 +195,7 @@ def log( # type: ignore # yes, the signatures are different here. else: lvl = self._default_level - assert not lvl is None, "lvl should not be None at this point" + assert lvl is not None, "lvl should not be None at this point" # print to console with formatting # ======================================== diff --git a/muutils/misc.py b/muutils/misc.py index fdc9453f..7382bdba 100644 --- a/muutils/misc.py +++ b/muutils/misc.py @@ -286,7 +286,7 @@ def str_to_numeric( # fractions if "/" in quantity: try: - assert quantity.count("/") == 1, f"too many '/'" + assert quantity.count("/") == 1, "too many '/'" # split and strip num, den = quantity.split("/") num = num.strip() @@ -299,7 +299,7 @@ def str_to_numeric( # assert that both are digits assert ( num.isdigit() and den.isdigit() - ), f"numerator and denominator must be digits" + ), "numerator and denominator must be digits" # return the fraction result = num_sign * ( int(num) / int(den) diff --git a/muutils/sysinfo.py b/muutils/sysinfo.py index 11a3979c..a19c9fea 100644 --- a/muutils/sysinfo.py +++ b/muutils/sysinfo.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import subprocess import sys import typing @@ -105,7 +104,7 @@ def pytorch() -> dict: "device": current_device, "name": dev_prop.name, "version": { - f"major": dev_prop.major, + "major": dev_prop.major, "minor": dev_prop.minor, }, "total_memory": dev_prop.total_memory, diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 4119c73a..015faa20 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -1,9 +1,7 @@ from __future__ import annotations -import sys from typing import Any -import pytest from muutils.json_serialize import ( JsonSerializer, diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 1e9372c7..a8a189a2 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys import typing from typing import Any diff --git a/tests/unit/validate_type/test_validate_type_special.py b/tests/unit/validate_type/test_validate_type_special.py index 4ca4f069..cd33500a 100644 --- a/tests/unit/validate_type/test_validate_type_special.py +++ b/tests/unit/validate_type/test_validate_type_special.py @@ -1,7 +1,6 @@ from __future__ import annotations import typing -from typing import Any, Optional, Union import pytest From 8e6fa71533a494dfbe3ff079b7cddcfd77777dbb Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:37:10 -0700 Subject: [PATCH 118/158] manual fixes --- muutils/dictmagic.py | 8 ++++---- muutils/json_serialize/array.py | 2 +- muutils/json_serialize/json_serialize.py | 2 +- muutils/json_serialize/util.py | 10 +++++----- muutils/nbutils/configure_notebook.py | 2 +- tests/unit/test_kappa.py | 16 ++++++++-------- tests/unit/test_tensor_utils.py | 4 +++- 7 files changed, 23 insertions(+), 21 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 28ea0d02..c34ac34d 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -359,8 +359,8 @@ def tuple_dims_replace( return tuple(dims_names_map.get(x, x) for x in t) -TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] -TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] +TensorDict = typing.Dict[str, "torch.Tensor|np.ndarray"] # type: ignore[name-defined] # noqa: F821 +TensorIterable = Iterable[typing.Tuple[str, "torch.Tensor|np.ndarray"]] # type: ignore[name-defined] # noqa: F821 TensorDictFormats = Literal["dict", "json", "yaml", "yml"] @@ -465,10 +465,10 @@ def condense_tensor_dict( # identity function for shapes_convert if not provided if shapes_convert is None: - shapes_convert = lambda x: x + shapes_convert = lambda x: x # noqa: E731 # convert to iterable - data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = ( # type: ignore + data_items: "Iterable[tuple[str, Union[torch.Tensor,np.ndarray]]]" = ( # type: ignore # noqa: F821 data.items() if hasattr(data, "items") and callable(data.items) else data # type: ignore ) diff --git a/muutils/json_serialize/array.py b/muutils/json_serialize/array.py index 4a810f0c..d08c9bea 100644 --- a/muutils/json_serialize/array.py +++ b/muutils/json_serialize/array.py @@ -43,7 +43,7 @@ def arr_metadata(arr) -> dict[str, list[int] | str | int]: def serialize_array( - jser: "JsonSerializer", # type: ignore[name-defined] + jser: "JsonSerializer", # type: ignore[name-defined] # noqa: F821 arr: np.ndarray, path: str | Sequence[str | int], array_mode: ArrayMode | None = None, diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index 6f290436..23cf4d5b 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -10,7 +10,7 @@ from muutils.json_serialize.array import ArrayMode, serialize_array except ImportError as e: ArrayMode = str # type: ignore[misc] - serialize_array = lambda *args, **kwargs: None + serialize_array = lambda *args, **kwargs: None # noqa: E731 warnings.warn( f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}", ImportWarning, diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 87d80d03..b17512be 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -62,19 +62,19 @@ def __contains__(self, x: Any) -> bool: return True -def isinstance_namedtuple(x): +def isinstance_namedtuple(x: Any) -> bool: """checks if `x` is a `namedtuple` credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple """ - t = type(x) - b = t.__bases__ + t: type = type(x) + b: tuple = t.__bases__ if len(b) != 1 or b[0] != tuple: return False - f = getattr(t, "_fields", None) + f: tuple = getattr(t, "_fields", None) if not isinstance(f, tuple): return False - return all(type(n) == str for n in f) + return all(isinstance(n, str) for n in f) def try_catch(func: Callable): diff --git a/muutils/nbutils/configure_notebook.py b/muutils/nbutils/configure_notebook.py index db314e7d..1e56d1ef 100644 --- a/muutils/nbutils/configure_notebook.py +++ b/muutils/nbutils/configure_notebook.py @@ -144,7 +144,7 @@ def setup_plots( plt.rcParams["savefig.format"] = FIG_OUTPUT_FMT if FIG_OUTPUT_FMT in TIKZPLOTLIB_FORMATS: try: - import tikzplotlib # type: ignore[import] + import tikzplotlib # type: ignore[import] # noqa: F401 except ImportError: warnings.warn( f"Tikzplotlib not installed. Cannot save figures in Tikz format '{FIG_OUTPUT_FMT}', things might break." diff --git a/tests/unit/test_kappa.py b/tests/unit/test_kappa.py index ab8b169c..860bbfd1 100644 --- a/tests/unit/test_kappa.py +++ b/tests/unit/test_kappa.py @@ -6,50 +6,50 @@ def test_Kappa_returns_Kappa_instance(): - func = lambda x: x**2 + func = lambda x: x**2 # noqa: E731 result = Kappa(func) assert isinstance(result, Mapping), "Kappa did not return a Mapping instance" def test_Kappa_getitem_calls_func(): - func = lambda x: x**2 + func = lambda x: x**2 # noqa: E731 result = Kappa(func) assert result[2] == 4, "__getitem__ did not correctly call the input function" def test_Kappa_doc_is_correctly_formatted(): - func = lambda x: x**2 + func = lambda x: x**2 # noqa: E731 result = Kappa(func) expected_doc = _BASE_DOC + "None" assert result.doc == expected_doc, "doc was not correctly formatted" def test_Kappa_getitem_works_with_different_functions(): - func = lambda x: x + 1 + func = lambda x: x + 1 # noqa: E731 result = Kappa(func) assert result[2] == 3, "__getitem__ did not correctly call the input function" - func = lambda x: str(x) + func = lambda x: str(x) # noqa: E731 result = Kappa(func) assert result[2] == "2", "__getitem__ did not correctly call the input function" def test_Kappa_iter_raises_NotImplementedError(): - func = lambda x: x**2 + func = lambda x: x**2 # noqa: E731 result = Kappa(func) with pytest.raises(NotImplementedError): iter(result) def test_Kappa_len_raises_NotImplementedError(): - func = lambda x: x**2 + func = lambda x: x**2 # noqa: E731 result = Kappa(func) with pytest.raises(NotImplementedError): len(result) def test_Kappa_doc_works_with_function_with_docstring(): - func = lambda x: x**2 + func = lambda x: x**2 # noqa: E731 func.__doc__ = "This is a test function" result = Kappa(func) expected_doc = _BASE_DOC + "This is a test function" diff --git a/tests/unit/test_tensor_utils.py b/tests/unit/test_tensor_utils.py index 97123fb0..89e9e398 100644 --- a/tests/unit/test_tensor_utils.py +++ b/tests/unit/test_tensor_utils.py @@ -32,7 +32,9 @@ def test_jaxtype_factory(): assert "array_type = " in ATensor.__doc__ x = ATensor[(1, 2, 3), np.float32] - x = ATensor["dim1 dim2", np.float32] + print(x) + y = ATensor["dim1 dim2", np.float32] + print(y) def test_numpy_to_torch_dtype(): From 8f6467123aedf34cb3f0b4ae10d765887ff93f41 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:39:38 -0700 Subject: [PATCH 119/158] more manual fixes --- muutils/json_serialize/json_serialize.py | 2 +- muutils/logger/timing.py | 5 +++-- muutils/nbutils/configure_notebook.py | 4 ++-- muutils/tensor_utils.py | 8 ++++---- tests/unit/logger/test_timer_context.py | 1 + tests/unit/test_dictmagic.py | 3 ++- 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/muutils/json_serialize/json_serialize.py b/muutils/json_serialize/json_serialize.py index 23cf4d5b..67b33548 100644 --- a/muutils/json_serialize/json_serialize.py +++ b/muutils/json_serialize/json_serialize.py @@ -125,7 +125,7 @@ def serialize(self) -> dict: def _serialize_override_serialize_func( self: "JsonSerializer", obj: Any, path: ObjectPath ) -> JSONitem: - obj_cls: type = type(obj) + # obj_cls: type = type(obj) # if hasattr(obj_cls, "_register_self") and callable(obj_cls._register_self): # obj_cls._register_self() diff --git a/muutils/logger/timing.py b/muutils/logger/timing.py index c0c8e0c6..1b7b2e99 100644 --- a/muutils/logger/timing.py +++ b/muutils/logger/timing.py @@ -81,6 +81,7 @@ def get_progress_default(self, i: int) -> str: timing_raw: dict[str, float] = self.get_timing_raw(i) percent_str: str = str(int(timing_raw["percent"] * 100)).ljust(2) - iters_str: str = f"{str(i).ljust(self.total_str_len)}/{self.n_total}" - timing_str: str + # TODO: get_progress_default + # iters_str: str = f"{str(i).ljust(self.total_str_len)}/{self.n_total}" + # timing_str: str return f"{percent_str}% {self.get_pbar(i)}" diff --git a/muutils/nbutils/configure_notebook.py b/muutils/nbutils/configure_notebook.py index 1e56d1ef..62352f50 100644 --- a/muutils/nbutils/configure_notebook.py +++ b/muutils/nbutils/configure_notebook.py @@ -33,7 +33,7 @@ class PlotlyNotInstalledWarning(UserWarning): IN_JUPYTER = False # muutils imports -from muutils.mlutils import get_device, set_reproducibility +from muutils.mlutils import get_device, set_reproducibility # noqa: E402 # handling figures PlottingMode = typing.Literal["ignore", "inline", "widget", "save"] @@ -198,7 +198,7 @@ def configure_notebook( fig_config: dict | None = None, fig_basepath: str | None = None, close_after_plotshow: bool = False, -) -> "torch.device|None": # type: ignore[name-defined] +) -> "torch.device|None": # type: ignore[name-defined] # noqa: F821 """Shared Jupyter notebook setup steps: - Set random seeds and library reproducibility settings - Set device based on availability diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index 8722546e..abc61ea2 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -280,11 +280,11 @@ def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dt def pad_tensor( - tensor: jaxtyping.Shaped[torch.Tensor, "dim1"], + tensor: jaxtyping.Shaped[torch.Tensor, "dim1"], # noqa: F821 padded_length: int, pad_value: float = 0.0, rpad: bool = False, -) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]: +) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]: # noqa: F821 """pad a 1-d tensor on the left with pad_value to length `padded_length` set `rpad = True` to pad on the right instead""" @@ -319,11 +319,11 @@ def rpad_tensor( def pad_array( - array: jaxtyping.Shaped[np.ndarray, "dim1"], + array: jaxtyping.Shaped[np.ndarray, "dim1"], # noqa: F821 padded_length: int, pad_value: float = 0.0, rpad: bool = False, -) -> jaxtyping.Shaped[np.ndarray, "padded_length"]: +) -> jaxtyping.Shaped[np.ndarray, "padded_length"]: # noqa: F821 """pad a 1-d array on the left with pad_value to length `padded_length` set `rpad = True` to pad on the right instead""" diff --git a/tests/unit/logger/test_timer_context.py b/tests/unit/logger/test_timer_context.py index cd9d8bc2..5d164abd 100644 --- a/tests/unit/logger/test_timer_context.py +++ b/tests/unit/logger/test_timer_context.py @@ -6,6 +6,7 @@ def test_timer_context() -> None: with TimerContext() as timer: x: float = 1.0 + print(x) assert isinstance(timer.start_time, float) assert isinstance(timer.end_time, float) diff --git a/tests/unit/test_dictmagic.py b/tests/unit/test_dictmagic.py index 4407bb54..d95f6691 100644 --- a/tests/unit/test_dictmagic.py +++ b/tests/unit/test_dictmagic.py @@ -262,7 +262,8 @@ def test_condense_tensor_dict_basic(tensor_data): def test_condense_tensor_dict_shapes_convert(tensor_data): - shapes_convert = lambda x: x # Returning the actual shape tuple + # Returning the actual shape tuple + shapes_convert = lambda x: x # noqa: E731 assert condense_tensor_dict( tensor_data, shapes_convert=shapes_convert, From 5e51679889b5f3d46e721f4fa25a0402e79a69c5 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:40:42 -0700 Subject: [PATCH 120/158] fixed using --unsafe-fixes --- tests/unit/validate_type/test_validate_type.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index de629dfb..8784b788 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -370,28 +370,28 @@ def test_validate_type_complex(): assert validate_type({1, (2, 3)}, typing.Set[Union[int, typing.Tuple[int, int]]]) assert validate_type((1, ("a", "b")), typing.Tuple[int, typing.Tuple[str, str]]) assert validate_type([{"key": "value"}], typing.List[typing.Dict[str, str]]) - assert validate_type([{"key": 2}], typing.List[typing.Dict[str, str]]) == False + assert validate_type([{"key": 2}], typing.List[typing.Dict[str, str]]) is False assert validate_type([[1, 2], [3, 4]], typing.List[typing.List[int]]) - assert validate_type([[1, 2], [3, "4"]], typing.List[typing.List[int]]) == False + assert validate_type([[1, 2], [3, "4"]], typing.List[typing.List[int]]) is False assert validate_type([(1, 2), (3, 4)], typing.List[typing.Tuple[int, int]]) assert ( - validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) == False + validate_type([(1, 2), (3, "4")], typing.List[typing.Tuple[int, int]]) is False ) assert validate_type({1: "one", 2: "two"}, typing.Dict[int, str]) - assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) == False + assert validate_type({1: "one", 2: 2}, typing.Dict[int, str]) is False assert validate_type([(1, "one"), (2, "two")], typing.List[typing.Tuple[int, str]]) assert ( validate_type([(1, "one"), (2, 2)], typing.List[typing.Tuple[int, str]]) - == False + is False ) assert validate_type({1: [1, 2], 2: [3, 4]}, typing.Dict[int, typing.List[int]]) assert ( validate_type({1: [1, 2], 2: [3, "4"]}, typing.Dict[int, typing.List[int]]) - == False + is False ) assert validate_type([(1, "a"), (2, "b")], typing.List[typing.Tuple[int, str]]) assert ( - validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) == False + validate_type([(1, "a"), (2, 2)], typing.List[typing.Tuple[int, str]]) is False ) From af28d613acd493ab9dc81073e6010d8999e1f708 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:42:49 -0700 Subject: [PATCH 121/158] fix setup of formatters/linters in CI --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 3c4237f4..8ebe0d6d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -19,7 +19,7 @@ jobs: fetch-depth: 0 - name: Install linters - run: pip install pycln isort black + run: make setup-format RUN_GLOBAL=1 - name: Run Format Checks run: make check-format RUN_GLOBAL=1 From c4af1038779fe76f786f23141759e92bf7150126 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 14:43:15 -0700 Subject: [PATCH 122/158] typing fix --- muutils/json_serialize/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index b17512be..e47ba597 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -71,7 +71,7 @@ def isinstance_namedtuple(x: Any) -> bool: b: tuple = t.__bases__ if len(b) != 1 or b[0] != tuple: return False - f: tuple = getattr(t, "_fields", None) + f: Any = getattr(t, "_fields", None) if not isinstance(f, tuple): return False return all(isinstance(n, str) for n in f) From cf4417813631064d550ddc17cb4960b54e9dde30 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 15:15:24 -0700 Subject: [PATCH 123/158] better handling of only format deps --- .lastversion => .github/.lastversion | 0 .../dev-requirements.txt | 1 - .github/lint-requirements.txt | 8 +++++++ .gitignore | 7 +++--- makefile | 23 +++++++++++-------- poetry.lock | 3 ++- pyproject.toml | 22 +++++++++--------- 7 files changed, 39 insertions(+), 25 deletions(-) rename .lastversion => .github/.lastversion (100%) rename dev-requirements.txt => .github/dev-requirements.txt (98%) create mode 100644 .github/lint-requirements.txt diff --git a/.lastversion b/.github/.lastversion similarity index 100% rename from .lastversion rename to .github/.lastversion diff --git a/dev-requirements.txt b/.github/dev-requirements.txt similarity index 98% rename from dev-requirements.txt rename to .github/dev-requirements.txt index 18b678be..09b9d1d9 100644 --- a/dev-requirements.txt +++ b/.github/dev-requirements.txt @@ -43,7 +43,6 @@ pytest==8.2.2 ; python_version >= "3.8" and python_version < "4.0" python-dateutil==2.9.0.post0 ; python_version >= "3.8" and python_version < "4.0" pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4" rich==13.7.1 ; python_version >= "3.8" and python_version < "4" -ruff==0.4.9 ; python_version >= "3.8" and python_version < "4.0" shellingham==1.5.4 ; python_version >= "3.8" and python_version < "4" six==1.16.0 ; python_version >= "3.8" and python_version < "4.0" stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0" diff --git a/.github/lint-requirements.txt b/.github/lint-requirements.txt new file mode 100644 index 00000000..fed6c25c --- /dev/null +++ b/.github/lint-requirements.txt @@ -0,0 +1,8 @@ +colorama==0.4.6 ; python_version >= "3.8" and python_version < "4.0" and sys_platform == "win32" +exceptiongroup==1.2.1 ; python_version >= "3.8" and python_version < "3.11" +iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "4.0" +packaging==24.1 ; python_version >= "3.8" and python_version < "4.0" +pluggy==1.5.0 ; python_version >= "3.8" and python_version < "4.0" +pytest==8.2.2 ; python_version >= "3.8" and python_version < "4.0" +ruff==0.4.9 ; python_version >= "3.8" and python_version < "4.0" +tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11" diff --git a/.gitignore b/.gitignore index 2d99a757..6508e22b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# local stuff (pypi token, commit log) +.github/local/** + # this one is cursed tests/unit/validate_type/test_validate_type_MODERN.py # test notebook @@ -5,9 +8,7 @@ _test.ipynb # junk data JUNK_DATA_PATH/ junk_data -# misc -.pypi-token -.commit_log + .vscode/ # caches __pycache__/ diff --git a/makefile b/makefile index 44c29682..07cbb8f4 100644 --- a/makefile +++ b/makefile @@ -12,19 +12,22 @@ COVERAGE_REPORTS_DIR := docs/coverage TESTS_DIR := tests/unit # temp directory to clean up TESTS_TEMP_DIR := tests/_temp +# dev and lint requirements.txt files +REQ_DEV := .github/dev-requirements.txt +REQ_LINT := .github/lint-requirements.txt # probably don't change these: # -------------------------------------------------- # will print this token when publishing -PYPI_TOKEN_FILE := .pypi-token +PYPI_TOKEN_FILE := .github/local/.pypi-token # the last version that was auto-uploaded. will use this to create a commit log for version tag -LAST_VERSION_FILE := .lastversion +LAST_VERSION_FILE := .github/.lastversion # where the pyproject.toml file is PYPROJECT := pyproject.toml # base python to use. Will add `poetry run` in front of this if `RUN_GLOBAL` is not set to 1 PYTHON_BASE := python # where the commit log will be stored -COMMIT_LOG_FILE := .commit_log +COMMIT_LOG_FILE := .github/local/.commit_log @@ -129,7 +132,7 @@ version: .PHONY: setup-format setup-format: @echo "install only packages needed for formatting, direct via pip (useful for CI)" - $(PYTHON_BASE) -c 'import re,tomllib; cfg = tomllib.load(open("$(PYPROJECT)", "rb")); deps = [(pkg, re.match(r"^\D*(\d.*)", ver).group(1)) for pkg, ver in cfg["tool"]["poetry"]["group"]["dev"]["dependencies"].items() if pkg in ["ruff", "pycln"]]; print(" ".join([f"{pkg}=={ver}" for pkg,ver in deps]))' | xargs $(PYTHON) -m pip install + $(PYTHON) -m pip install -r $(REQ_LINT) .PHONY: format format: @@ -203,19 +206,21 @@ verify-git: # no zanj, it gets special treatment because it depends on muutils # without urls since pytorch extra index breaks things # no torch because we install it manually in CI -EXPORT_ARGS := -E array_no_torch -E notebook --with dev --without-hashes --without-urls +EXPORT_ARGS := -E array_no_torch -E notebook --with dev --with lint --without-hashes --without-urls .PHONY: dep-dev dep-dev: - @echo "exporting dev and extras dependencies to dev-requirements.txt" + @echo "exporting dev and extras deps to $(REQ_DEV), lint/format deps to $(REQ_LINT)" poetry update - poetry export $(EXPORT_ARGS) --output dev-requirements.txt + poetry export $(EXPORT_ARGS) --output $(REQ_DEV) + poetry export --only lint --without-hashes --without-urls --output $(REQ_LINT) .PHONY: check-dep-dev check-dep-dev: - @echo "checking requirements.txt matches poetry dependencies" + @echo "checking poetry lock is good, exported requirements match poetry" poetry check --lock - poetry export $(EXPORT_ARGS) | diff - dev-requirements.txt + poetry export $(EXPORT_ARGS) | diff - $(REQ_DEV) + poetry export --only lint --without-hashes --without-urls | diff - $(REQ_LINT) .PHONY: build build: diff --git a/poetry.lock b/poetry.lock index b858b705..4c2ba0c8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1935,10 +1935,11 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [extras] array = ["jaxtyping", "numpy", "torch"] array-no-torch = ["jaxtyping", "numpy"] +lint = [] notebook = ["ipython"] zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "e2a09de8d7c53e28a74a03217b0d45ccd5f48ba2c3993f0ea07baef8a876d91b" +content-hash = "ac943e0702af9778b24bc0e2a2ded1b3e59e8f3dd46c8a6636226aeace69a547" diff --git a/pyproject.toml b/pyproject.toml index 1aa89971..b333fb63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,18 +25,7 @@ jaxtyping = { version = "^0.2.12", optional = true } ipython = { version = "^8.20.0", optional = true, python = "^3.10" } zanj = { version = "^0.2.2", optional = true, python = "^3.10" } -[tool.poetry.extras] -array = ["numpy", "torch", "jaxtyping"] -array_no_torch = ["numpy", "jaxtyping"] -notebook = ["ipython"] -zanj = ["zanj"] - [tool.poetry.group.dev.dependencies] -pytest = "^8.2.2" -ruff = "^0.4.8" -# black = "^24.1.1" -# pylint = "^2.16.4" -# isort = "^5.12.0" pycln = "^2.1.3" mypy = "^1.0.1" pytest-cov = "^4.1.0" @@ -44,6 +33,17 @@ coverage-badge = "^1.1.0" matplotlib = "^3.0.0" plotly = "^5.0.0" +[tool.poetry.group.lint.dependencies] +pytest = "^8.2.2" +ruff = "^0.4.8" + +[tool.poetry.extras] +array = ["numpy", "torch", "jaxtyping"] +array_no_torch = ["numpy", "jaxtyping"] +notebook = ["ipython"] +zanj = ["zanj"] +lint = ["ruff", "pycln"] + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" From 983428839929b7f56c85590c4ad62d312ac75036 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 15:17:02 -0700 Subject: [PATCH 124/158] typo --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b333fb63..dc0f7f5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ ipython = { version = "^8.20.0", optional = true, python = "^3.10" } zanj = { version = "^0.2.2", optional = true, python = "^3.10" } [tool.poetry.group.dev.dependencies] -pycln = "^2.1.3" +pytest = "^8.2.2" mypy = "^1.0.1" pytest-cov = "^4.1.0" coverage-badge = "^1.1.0" @@ -34,7 +34,7 @@ matplotlib = "^3.0.0" plotly = "^5.0.0" [tool.poetry.group.lint.dependencies] -pytest = "^8.2.2" +pycln = "^2.1.3" ruff = "^0.4.8" [tool.poetry.extras] From 9647bb6d9b1fe4884b2e377fe97e39e313e42221 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 15:18:19 -0700 Subject: [PATCH 125/158] fix ref to dev-requirements.txt --- .github/workflows/checks.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 8ebe0d6d..684cbe1c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -46,14 +46,14 @@ jobs: poetry self add poetry-plugin-export poetry self show plugins - - name: Check poetry.lock and dev-requirements.txt + - name: Check poetry.lock and .github/dev-requirements.txt run: make check-dep-dev - name: Install uv run: pip install uv - name: Install dependencies - run: uv pip install -r dev-requirements.txt --system --no-deps # we already should have all dependencies exported into dev-requirements.txt + run: uv pip install -r .github/dev-requirements.txt --system --no-deps # we already should have all dependencies exported into .github/dev-requirements.txt - name: Install torch (special) run: uv pip install torch==2.3.1+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu @@ -100,7 +100,7 @@ jobs: - name: Install dependencies # install torch first to avoid pytorch index messing things up run: | - uv pip install -r dev-requirements.txt --system --no-deps + uv pip install -r .github/dev-requirements.txt --system --no-deps uv pip install torch==${{ matrix.versions.torch}}+cpu --system --extra-index-url https://download.pytorch.org/whl/cpu - name: Install muutils From 386a6d1cfae4d441313cbd9b8a53d6b1431cb0e6 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 15:19:41 -0700 Subject: [PATCH 126/158] make dep-dev --- .github/dev-requirements.txt | 1 + .github/lint-requirements.txt | 23 ++++++++++++++++------- poetry.lock | 2 +- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/.github/dev-requirements.txt b/.github/dev-requirements.txt index 09b9d1d9..18b678be 100644 --- a/.github/dev-requirements.txt +++ b/.github/dev-requirements.txt @@ -43,6 +43,7 @@ pytest==8.2.2 ; python_version >= "3.8" and python_version < "4.0" python-dateutil==2.9.0.post0 ; python_version >= "3.8" and python_version < "4.0" pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4" rich==13.7.1 ; python_version >= "3.8" and python_version < "4" +ruff==0.4.9 ; python_version >= "3.8" and python_version < "4.0" shellingham==1.5.4 ; python_version >= "3.8" and python_version < "4" six==1.16.0 ; python_version >= "3.8" and python_version < "4.0" stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0" diff --git a/.github/lint-requirements.txt b/.github/lint-requirements.txt index fed6c25c..5c77b646 100644 --- a/.github/lint-requirements.txt +++ b/.github/lint-requirements.txt @@ -1,8 +1,17 @@ -colorama==0.4.6 ; python_version >= "3.8" and python_version < "4.0" and sys_platform == "win32" -exceptiongroup==1.2.1 ; python_version >= "3.8" and python_version < "3.11" -iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "4.0" -packaging==24.1 ; python_version >= "3.8" and python_version < "4.0" -pluggy==1.5.0 ; python_version >= "3.8" and python_version < "4.0" -pytest==8.2.2 ; python_version >= "3.8" and python_version < "4.0" +click==8.1.7 ; python_version >= "3.8" and python_version < "4" +colorama==0.4.6 ; python_version >= "3.8" and python_version < "4" and platform_system == "Windows" +libcst==1.1.0 ; python_version >= "3.8" and python_version < "4" +markdown-it-py==3.0.0 ; python_version >= "3.8" and python_version < "4" +mdurl==0.1.2 ; python_version >= "3.8" and python_version < "4" +mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "4" +pathspec==0.12.1 ; python_version >= "3.8" and python_version < "4" +pycln==2.4.0 ; python_version >= "3.8" and python_version < "4" +pygments==2.18.0 ; python_version >= "3.8" and python_version < "4" +pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4" +rich==13.7.1 ; python_version >= "3.8" and python_version < "4" ruff==0.4.9 ; python_version >= "3.8" and python_version < "4.0" -tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11" +shellingham==1.5.4 ; python_version >= "3.8" and python_version < "4" +tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4" +typer==0.12.3 ; python_version >= "3.8" and python_version < "4" +typing-extensions==4.12.2 ; python_version >= "3.8" and python_version < "4" +typing-inspect==0.9.0 ; python_version >= "3.8" and python_version < "4" diff --git a/poetry.lock b/poetry.lock index 4c2ba0c8..ea550967 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1942,4 +1942,4 @@ zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "ac943e0702af9778b24bc0e2a2ded1b3e59e8f3dd46c8a6636226aeace69a547" +content-hash = "74fc33c5d521b3121b9894ec25831ab3a7637de2cff3fa04eb2128c211092914" From f34f28d6dab3263cfb1568aed6dfc50bdd01aa0d Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 15:21:20 -0700 Subject: [PATCH 127/158] make dep-dev --- poetry.lock | 3 +-- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index ea550967..2af38d47 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1935,11 +1935,10 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [extras] array = ["jaxtyping", "numpy", "torch"] array-no-torch = ["jaxtyping", "numpy"] -lint = [] notebook = ["ipython"] zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "74fc33c5d521b3121b9894ec25831ab3a7637de2cff3fa04eb2128c211092914" +content-hash = "ad931582fd4069a70774c8f616036c92cc008e0bacde72b07e2260f169fdb905" diff --git a/pyproject.toml b/pyproject.toml index dc0f7f5f..434d3de1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ array = ["numpy", "torch", "jaxtyping"] array_no_torch = ["numpy", "jaxtyping"] notebook = ["ipython"] zanj = ["zanj"] -lint = ["ruff", "pycln"] [build-system] requires = ["poetry-core"] From 7bdf44562e7dc8447be310711edf3210741af659 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 15:55:43 -0700 Subject: [PATCH 128/158] fixing some todos, get rid of _loading_test_wrapper --- pyproject.toml | 10 ------ .../test_sdc_defaults.py | 14 ++------ .../test_sdc_properties_nested.py | 11 ++----- .../test_serializable_dataclass.py | 32 ++++++------------- 4 files changed, 14 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 434d3de1..9533aed6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,9 +52,6 @@ build-backend = "poetry.core.masonry.api" # url = "https://download.pytorch.org/whl/cpu" # priority = "explicit" -# TODO: make all of the following ignored across all formatting/linting -# tests/input_data, tests/junk_data, muutils/_wip - [tool.pytest.ini_options] filterwarnings = [ "ignore::muutils.nbutils.configure_notebook.UnknownFigureFormatWarning", # don't show warning for unknown figure format @@ -65,13 +62,6 @@ filterwarnings = [ [tool.ruff] # Exclude the directories specified in the global excludes exclude = ["tests/input_data", "tests/junk_data", "muutils/_wip"] -# # allow trailing commas -# no-trailing-comma = false - -# # Per-file rule ignores -# [tool.ruff.per-file-ignores] -# # Ignore the "F401" (unused imports) rule for __init__.py files -# "__init__.py" = ["F401"] [tool.pycln] all = true diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 015faa20..92bce05b 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -1,6 +1,5 @@ from __future__ import annotations -from typing import Any from muutils.json_serialize import ( @@ -13,13 +12,6 @@ # pylint: disable=missing-class-docstring -# TODO: get rid of all _loading_test_wrapper functions across all files -def _loading_test_wrapper(cls, data) -> Any: - """wrapper for testing the load function, which accounts for version differences""" - loaded = cls.load(data) - return loaded - - @serializable_dataclass class Config(SerializableDataclass): name: str = serializable_field(default="default_name") @@ -36,7 +28,7 @@ def test_sdc_empty(): "batch_size": 64, "__format__": "Config(SerializableDataclass)", } - recovered = _loading_test_wrapper(Config, serialized) + recovered = Config.load(serialized) assert recovered == instance @@ -50,7 +42,7 @@ def test_sdc_strip_format_jser(): "batch_size": 64, "__write_format__": "Config(SerializableDataclass)", } - recovered = _loading_test_wrapper(Config, serialized) + recovered = Config.load(serialized) assert recovered == instance @@ -73,5 +65,5 @@ class ComplicatedConfig(SerializableDataclass): def test_sdc_empty_complicated(): instance = ComplicatedConfig() serialized = instance.serialize() - recovered = _loading_test_wrapper(ComplicatedConfig, serialized) + recovered = ComplicatedConfig.load(serialized) assert recovered == instance diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 35f252fc..28e5179c 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -1,7 +1,6 @@ from __future__ import annotations import sys -from typing import Any import pytest @@ -12,12 +11,6 @@ print(f"{SUPPORTS_KW_ONLY = }") -def _loading_test_wrapper(cls, data) -> Any: - """wrapper for testing the load function, which accounts for version differences""" - loaded = cls.load(data) - return loaded - - @serializable_dataclass class Person(SerializableDataclass): first_name: str @@ -50,7 +43,7 @@ def test_serialize_person(): "__format__": "Person(SerializableDataclass)", } - recovered = _loading_test_wrapper(Person, serialized) + recovered = Person.load(serialized) assert recovered == instance @@ -73,6 +66,6 @@ def test_serialize_titled_person(): "full_title": "Dr. Jane Smith", } - recovered = _loading_test_wrapper(TitledPerson, serialized) + recovered = TitledPerson.load(serialized) assert recovered == instance diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index a8a189a2..cc67cb23 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -14,12 +14,6 @@ # pylint: disable=missing-class-docstring, unused-variable -def _loading_test_wrapper(cls, data) -> Any: - """wrapper for testing the load function, which accounts for version differences""" - loaded = cls.load(data) - return loaded - - @serializable_dataclass class BasicAutofields(SerializableDataclass): a: str @@ -134,7 +128,7 @@ def test_field_options_serialization(field_options_instance): def test_field_options_loading(field_options_instance): serialized = field_options_instance.serialize() - loaded = _loading_test_wrapper(FieldOptions, serialized) # , assert_record_len=3) + loaded = FieldOptions.load(serialized) assert loaded == field_options_instance @@ -150,7 +144,7 @@ def test_with_property_serialization(with_property_instance): def test_with_property_loading(with_property_instance): serialized = with_property_instance.serialize() - loaded = _loading_test_wrapper(WithProperty, serialized) # , assert_record_len=2) + loaded = WithProperty.load(serialized) assert loaded == with_property_instance @@ -196,7 +190,7 @@ def test_nested_serialization(person_instance): def test_nested_loading(person_instance): serialized = person_instance.serialize() - loaded = _loading_test_wrapper(Person, serialized) # , assert_record_len=6) + loaded = Person.load(serialized) assert loaded == person_instance assert loaded.address == person_instance.address @@ -219,10 +213,7 @@ def full_name(self) -> str: serialized_data = my_instance.serialize() print(serialized_data) - loaded_instance = _loading_test_wrapper( - MyClass, - serialized_data, # , assert_record_len=3 - ) + loaded_instance = MyClass.load(serialized_data) print(loaded_instance) @@ -240,7 +231,7 @@ class SimpleClass(SerializableDataclass): "__format__": "SimpleClass(SerializableDataclass)", } - loaded = _loading_test_wrapper(SimpleClass, serialized) # , assert_record_len=2) + loaded = SimpleClass.load(serialized) assert loaded == simple @@ -274,7 +265,7 @@ def full_name(self) -> str: } assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}" - loaded = _loading_test_wrapper(FullPerson, serialized) # , assert_record_len=4) + loaded = FullPerson.load(serialized) assert loaded == person @@ -293,9 +284,7 @@ class CustomSerialization(SerializableDataclass): "__format__": "CustomSerialization(SerializableDataclass)", } - loaded = _loading_test_wrapper( - CustomSerialization, serialized - ) # , assert_record_len=1) + loaded = CustomSerialization.load(serialized) assert loaded == custom @@ -343,10 +332,7 @@ def test_nested_with_container(): assert serialized == expected_ser - loaded = _loading_test_wrapper( - Nested_with_Container, - serialized, # , assert_record_len=12 - ) + loaded = Nested_with_Container.load(serialized) assert loaded == instance @@ -386,5 +372,5 @@ def test_nested_custom(recwarn): # this will send some warnings but whatever "__format__": "nested_custom(SerializableDataclass)", } assert serialized == expected_ser - loaded = _loading_test_wrapper(nested_custom, serialized) + loaded = nested_custom.load(serialized) assert loaded == instance From 9979ba98211a016ce4d8e59940757b065c2fb46c Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 16:12:01 -0700 Subject: [PATCH 129/158] fix some todos in serializable dataclasses --- .../json_serialize/serializable_dataclass.py | 44 +++++++++++++++---- .../test_sdc_defaults.py | 1 - .../test_sdc_properties_nested.py | 21 ++++++++- 3 files changed, 56 insertions(+), 10 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 61a37209..a3786c43 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -403,13 +403,12 @@ class CantGetTypeHintsWarning(UserWarning): # Step 3: Create a custom serializable_dataclass decorator -# TODO: add a kwarg for always asserting type for all fields def serializable_dataclass( # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it _cls=None, # type: ignore *, init: bool = True, - repr: bool = True, # TODO: this overrides the actual `repr` method, can this be fixed? + repr: bool = True, # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass` eq: bool = True, order: bool = False, unsafe_hash: bool = False, @@ -466,8 +465,9 @@ def serialize(self) -> dict[str, Any]: result: dict[str, Any] = { "__format__": f"{self.__class__.__name__}(SerializableDataclass)" } - + # for each field in the class for field in dataclasses.fields(self): + # need it to be our special SerializableField if not isinstance(field, SerializableField): raise ValueError( f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a SerializableField, " @@ -475,15 +475,23 @@ def serialize(self) -> dict[str, Any]: "this state should be inaccessible, please report this bug!" ) + # try to save it if field.serialize: try: + # get the val value = getattr(self, field.name) + # if it is a serializable dataclass, serialize it if isinstance(value, SerializableDataclass): value = value.serialize() + # if the value has a serialization function, use that if hasattr(value, "serialize") and callable(value.serialize): value = value.serialize() + # if the field has a serialization function, use that + # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! elif field.serialization_fn: value = field.serialization_fn(value) + + # store the value in the result result[field.name] = value except Exception as e: raise ValueError( @@ -497,17 +505,24 @@ def serialize(self) -> dict[str, Any]: ) ) from e + # store each property if we can get it for prop in self._properties_to_serialize: if hasattr(cls, prop): value = getattr(self, prop) result[prop] = value + else: + raise AttributeError( + f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" + + f"but it is in {self._properties_to_serialize = }" + + f"\n{self = }" + ) return result # mypy thinks this isnt a classmethod @classmethod # type: ignore[misc] def load(cls, data: dict[str, Any] | T) -> Type[T]: - # TODO: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ + # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ if isinstance(data, cls): return data @@ -565,7 +580,6 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: ): # if no loading function but has a type hint with a load method, use that if isinstance(value, dict): - # TODO: should this be passing the whole data dict? value = field_type_hint.load(value) else: raise ValueError( @@ -579,14 +593,23 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: ctor_kwargs[field.name] = value # validate the type - if field.assert_type: + if field.assert_type and on_type_assert in ("raise", "warn"): if field.name in ctor_kwargs: if field_type_hint is not None: - # TODO: recursive type hint checking like pydantic? try: - assert validate_type( + # validate the type + type_is_valid: bool = validate_type( ctor_kwargs[field.name], field_type_hint ) + + # if not valid, raise or warn depending on the setting in the SerializableDataclass + if not type_is_valid: + msg: str = f"Field '{field.name}' on class {cls.__name__} has type {type(ctor_kwargs[field.name])}, but expected {field_type_hint}" + if on_type_assert == "raise": + raise ValueError(msg) + else: + warnings.warn(msg) + except Exception as e: raise ValueError( f"{field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {ctor_kwargs[field.name] = }" @@ -606,6 +629,11 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: f"Field '{field.name}' on class {cls} has no type hint, but {field.assert_type = }\n{field = }\n{cls_type_hints = }\n{data = }", CantGetTypeHintsWarning, ) + else: + if on_type_assert != "ignore": + raise ValueError( + f"Invalid value for {on_type_assert = }, expected 'raise', 'warn', or 'ignore'" + ) return cls(**ctor_kwargs) diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py index 92bce05b..8a1bf203 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_defaults.py @@ -1,7 +1,6 @@ from __future__ import annotations - from muutils.json_serialize import ( JsonSerializer, SerializableDataclass, diff --git a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py index 28e5179c..091ef291 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_sdc_properties_nested.py @@ -32,8 +32,27 @@ def full_title(self) -> str: return f"{self.title} {self.full_name}" +@serializable_dataclass( + kw_only=SUPPORTS_KW_ONLY, + properties_to_serialize=["full_name", "not_a_real_property"], +) +class AgedPerson_not_valid(Person): + title: str + + @property + def full_title(self) -> str: + return f"{self.title} {self.full_name}" + + +def test_invalid_properties_to_serialize(): + instance = AgedPerson_not_valid(first_name="Jane", last_name="Smith", title="Dr.") + + with pytest.raises(AttributeError): + instance.serialize() + + def test_serialize_person(): - instance = Person("John", "Doe") + instance = Person(first_name="John", last_name="Doe") serialized = instance.serialize() From 64679d32bb506b445e50169b38829d67b5446f51 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 17:15:04 -0700 Subject: [PATCH 130/158] wip junk --- makefile | 88 ++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 31 deletions(-) diff --git a/makefile b/makefile index 07cbb8f4..53d849e6 100644 --- a/makefile +++ b/makefile @@ -34,30 +34,6 @@ COMMIT_LOG_FILE := .github/local/.commit_log # reading information and command line options # ================================================== -# reading version -# -------------------------------------------------- -# assuming your pyproject.toml has a line that looks like `version = "0.0.1"`, will get the version -VERSION := $(shell python -c "import re; print(re.search(r'^version\s*=\s*\"(.+?)\"', open('$(PYPROJECT)').read(), re.MULTILINE).group(1))") -# read last auto-uploaded version from file -LAST_VERSION := $(shell [ -f $(LAST_VERSION_FILE) ] && cat $(LAST_VERSION_FILE) || echo NONE) - - -# getting commit log -# -------------------------------------------------- -# note that the commands at the end: -# 1) format the git log -# 2) replace backticks with single quotes, to avoid funny business -# 3) add a final newline, to make tac happy -# 4) reverse the order of the lines, so that the oldest commit is first -# 5) replace newlines with tabs, to prevent the newlines from being lost -ifeq ($(LAST_VERSION),NONE) - COMMIT_LOG_SINCE_LAST_VERSION := "No last version found, cannot generate commit log" -else - COMMIT_LOG_SINCE_LAST_VERSION := $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t') -# 1 2 3 4 5 -endif - - # RUN_GLOBAL=1 to use global `PYTHON_BASE` instead of `poetry run $(PYTHON_BASE)` # -------------------------------------------------- # for formatting, we might want to run python without setting up all of poetry @@ -68,14 +44,64 @@ else PYTHON = $(PYTHON_BASE) endif -# get the python version now that we have picked the python command +# reading version +# -------------------------------------------------- +# assuming your pyproject.toml has a line that looks like `version = "0.0.1"`, will get the version +VERSION := NULL +# read last auto-uploaded version from file +LAST_VERSION := NULL +# get the python version, now that we have picked the python command +PYTHON_VERSION := NULL +.PHONY: gen-version-info +gen-version-info: + $(eval VERSION := $(shell python -c "import re; print(re.search(r'^version\s*=\s*\"(.+?)\"', open('$(PYPROJECT)').read(), re.MULTILINE).group(1))") ) + $(eval LAST_VERSION := $(shell [ -f $(LAST_VERSION_FILE) ] && cat $(LAST_VERSION_FILE) || echo NULL) ) + $(eval PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')") ) + +# getting commit log +# note that if gen-version-info has not been run, this will not work # -------------------------------------------------- -PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')") +# explanation of the commit log generation: +# 1) in the shell 2) get the git log 3) since the last version +# 4) format the git log +# 5) replace backticks with single quotes, to avoid funny business +# 6) add a final newline, to make tac happy +# 7) reverse the order of the lines, so that the oldest commit is first +# 8) replace newlines with tabs, to prevent the newlines from being lost +# $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t') +# 1 2 3 4 5 6 7 8 +COMMIT_LOG_SINCE_LAST_VERSION := NULL +COMMIT_LOG_TEMP := NULL +.PHONY: gen-commit-log +gen-commit-log: gen-version-info + @echo "Generating commit log since last version" + @echo "Current version is $(VERSION), last auto-uploaded version is $(LAST_VERSION)" + if [ "$(LAST_VERSION)" = "NULL" ]; then \ + echo "LAST_VERSION is NULL, cant get commit log!"; \ + exit 1; \ + fi + @echo "Getting commit log since last version $(LAST_VERSION)" + $(eval COMMIT_LOG_TEMP := $(shell python -c "import subprocess, re; log=subprocess.check_output(['git', 'log', '$(LAST_VERSION)..HEAD', '--pretty=format:- %s (%h)']).decode('utf-8'); log=re.sub(r'[`()]', lambda m: '\\'+m.group(0), log); print('\\t'.join(log.split('\\n')[::-1]))")) + + @echo "Commit log temp:" + @echo $(COMMIT_LOG_TEMP) + $(eval COMMIT_LOG_SINCE_LAST_VERSION := $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t')) + @echo "Commit log since last version:" + @echo $(COMMIT_LOG_SINCE_LAST_VERSION) + +# $(shell git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)") +# $(eval COMMIT_LOG_SINCE_LAST_VERSION := ) + + + + + # looser typing, allow warnings for python <3.10 # -------------------------------------------------- -COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 10) else 0)") TYPECHECK_ARGS ?= +# COMPATIBILITY_MODE: whether to run in compatibility mode for python <3.10 +COMPATIBILITY_MODE := $(shell $(PYTHON) -c "import sys; print(1 if sys.version_info < (3, 10) else 0)") # options we might want to pass to pytest # -------------------------------------------------- @@ -102,7 +128,7 @@ endif # Update the PYTEST_OPTIONS to include the conditional ignore option ifeq ($(COMPATIBILITY_MODE), 1) - JUNK := $(info WARNING: Detected python version less than 3.10, some behavior will be different) + JUNK := $(info !!! WARNING !!!: Detected python version less than 3.10, some behavior will be different) PYTEST_OPTIONS += --ignore=tests/unit/validate_type/ TYPECHECK_ARGS += --disable-error-code misc --disable-error-code syntax --disable-error-code import-not-found endif @@ -115,7 +141,7 @@ endif default: help .PHONY: version -version: +version: gen-version-info gen-commit-log @echo "Current version is $(VERSION), last auto-uploaded version is $(LAST_VERSION)" @echo "Commit log since last version:" @echo "$(COMMIT_LOG_SINCE_LAST_VERSION)" | tr '\t' '\n' > $(COMMIT_LOG_FILE) @@ -228,7 +254,7 @@ build: poetry build .PHONY: publish -publish: check build verify-git version +publish: gen-version-info check build verify-git version @echo "run all checks, build, and then publish" @echo "Enter the new version number if you want to upload to pypi and create a new tag" @@ -271,7 +297,7 @@ clean: # listing targets, from stackoverflow # https://stackoverflow.com/questions/4219255/how-do-you-get-the-list-of-targets-in-a-makefile .PHONY: help -help: +help: gen-version-info @echo -n "list make targets" @echo ":" @cat Makefile | sed -n '/^\.PHONY: / h; /\(^\t@*echo\|^\t:\)/ {H; x; /PHONY/ s/.PHONY: \(.*\)\n.*"\(.*\)"/ make \1\t\2/p; d; x}'| sort -k2,2 |expand -t 25 From 1c808639f976b55abb16e9f2242ba6eb1318419b Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 17:25:58 -0700 Subject: [PATCH 131/158] commit log from python, much cleaner --- makefile | 38 ++++---------------------------------- 1 file changed, 4 insertions(+), 34 deletions(-) diff --git a/makefile b/makefile index 53d849e6..233fa9e2 100644 --- a/makefile +++ b/makefile @@ -59,45 +59,16 @@ gen-version-info: $(eval PYTHON_VERSION := $(shell $(PYTHON) -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}')") ) # getting commit log -# note that if gen-version-info has not been run, this will not work -# -------------------------------------------------- -# explanation of the commit log generation: -# 1) in the shell 2) get the git log 3) since the last version -# 4) format the git log -# 5) replace backticks with single quotes, to avoid funny business -# 6) add a final newline, to make tac happy -# 7) reverse the order of the lines, so that the oldest commit is first -# 8) replace newlines with tabs, to prevent the newlines from being lost -# $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t') -# 1 2 3 4 5 6 7 8 -COMMIT_LOG_SINCE_LAST_VERSION := NULL -COMMIT_LOG_TEMP := NULL .PHONY: gen-commit-log gen-commit-log: gen-version-info - @echo "Generating commit log since last version" - @echo "Current version is $(VERSION), last auto-uploaded version is $(LAST_VERSION)" if [ "$(LAST_VERSION)" = "NULL" ]; then \ echo "LAST_VERSION is NULL, cant get commit log!"; \ exit 1; \ fi - @echo "Getting commit log since last version $(LAST_VERSION)" - $(eval COMMIT_LOG_TEMP := $(shell python -c "import subprocess, re; log=subprocess.check_output(['git', 'log', '$(LAST_VERSION)..HEAD', '--pretty=format:- %s (%h)']).decode('utf-8'); log=re.sub(r'[`()]', lambda m: '\\'+m.group(0), log); print('\\t'.join(log.split('\\n')[::-1]))")) - - - @echo "Commit log temp:" - @echo $(COMMIT_LOG_TEMP) - $(eval COMMIT_LOG_SINCE_LAST_VERSION := $(shell (git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)" | tr '`' "'" ; echo) | tac | tr '\n' '\t')) - @echo "Commit log since last version:" - @echo $(COMMIT_LOG_SINCE_LAST_VERSION) + $(shell python -c "import subprocess; open('$(COMMIT_LOG_FILE)', 'w').write('\n'.join(reversed(subprocess.check_output(['git', 'log', '$(LAST_VERSION)'.strip() + '..HEAD', '--pretty=format:- %s (%h)']).decode('utf-8').strip().split('\n'))))") -# $(shell git log $(LAST_VERSION)..HEAD --pretty=format:"- %s (%h)") -# $(eval COMMIT_LOG_SINCE_LAST_VERSION := ) - - - - -# looser typing, allow warnings for python <3.10 +# loose typing, allow warnings for python <3.10 # -------------------------------------------------- TYPECHECK_ARGS ?= # COMPATIBILITY_MODE: whether to run in compatibility mode for python <3.10 @@ -141,10 +112,9 @@ endif default: help .PHONY: version -version: gen-version-info gen-commit-log +version: gen-commit-log @echo "Current version is $(VERSION), last auto-uploaded version is $(LAST_VERSION)" @echo "Commit log since last version:" - @echo "$(COMMIT_LOG_SINCE_LAST_VERSION)" | tr '\t' '\n' > $(COMMIT_LOG_FILE) @cat $(COMMIT_LOG_FILE) @if [ "$(VERSION)" = "$(LAST_VERSION)" ]; then \ echo "Python package $(VERSION) is the same as last published version $(LAST_VERSION), exiting!"; \ @@ -254,7 +224,7 @@ build: poetry build .PHONY: publish -publish: gen-version-info check build verify-git version +publish: gen-commit-log check build verify-git version @echo "run all checks, build, and then publish" @echo "Enter the new version number if you want to upload to pypi and create a new tag" From 06a348b0f097fc3282ef761e20b9bf9c88c17885 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 17:27:06 -0700 Subject: [PATCH 132/158] trying py 3.12 --- .github/workflows/checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 684cbe1c..ddf6b27d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -81,8 +81,8 @@ jobs: torch: '2.3.1' - python: '3.11' torch: '2.3.1' - # - python: '3.12' - # torch: '2.3.1' + - python: '3.12' + torch: '2.3.1' steps: - name: Checkout code uses: actions/checkout@v4 From a33337c9e5b99624bfdc37e6493d8d595d631b28 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 17:36:25 -0700 Subject: [PATCH 133/158] numpy deps for python 3.8 to 3.12 are complicated --- .github/dev-requirements.txt | 3 +- poetry.lock | 118 +++++++++++++++-------------------- pyproject.toml | 5 +- 3 files changed, 58 insertions(+), 68 deletions(-) diff --git a/.github/dev-requirements.txt b/.github/dev-requirements.txt index 18b678be..7cda1761 100644 --- a/.github/dev-requirements.txt +++ b/.github/dev-requirements.txt @@ -24,7 +24,8 @@ matplotlib==3.7.5 ; python_version >= "3.8" and python_version < "4.0" mdurl==0.1.2 ; python_version >= "3.8" and python_version < "4" mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "4.0" mypy==1.10.0 ; python_version >= "3.8" and python_version < "4.0" -numpy==1.24.4 ; python_version >= "3.8" and python_version < "4.0" +numpy==1.24.4 ; python_version >= "3.8" and python_version < "3.9" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "4.0" packaging==24.1 ; python_version >= "3.8" and python_version < "4.0" parso==0.8.4 ; python_version >= "3.10" and python_version < "4.0" pathspec==0.12.1 ; python_version >= "3.8" and python_version < "4" diff --git a/poetry.lock b/poetry.lock index 2af38d47..cda5808f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -43,68 +43,6 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -[[package]] -name = "contourpy" -version = "1.1.0" -description = "Python library for calculating contours of 2D quadrilateral grids" -optional = false -python-versions = ">=3.8" -files = [ - {file = "contourpy-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:89f06eff3ce2f4b3eb24c1055a26981bffe4e7264acd86f15b97e40530b794bc"}, - {file = "contourpy-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dffcc2ddec1782dd2f2ce1ef16f070861af4fb78c69862ce0aab801495dda6a3"}, - {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25ae46595e22f93592d39a7eac3d638cda552c3e1160255258b695f7b58e5655"}, - {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:17cfaf5ec9862bc93af1ec1f302457371c34e688fbd381f4035a06cd47324f48"}, - {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18a64814ae7bce73925131381603fff0116e2df25230dfc80d6d690aa6e20b37"}, - {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90c81f22b4f572f8a2110b0b741bb64e5a6427e0a198b2cdc1fbaf85f352a3aa"}, - {file = "contourpy-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53cc3a40635abedbec7f1bde60f8c189c49e84ac180c665f2cd7c162cc454baa"}, - {file = "contourpy-1.1.0-cp310-cp310-win32.whl", hash = "sha256:9b2dd2ca3ac561aceef4c7c13ba654aaa404cf885b187427760d7f7d4c57cff8"}, - {file = "contourpy-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:1f795597073b09d631782e7245016a4323cf1cf0b4e06eef7ea6627e06a37ff2"}, - {file = "contourpy-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0b7b04ed0961647691cfe5d82115dd072af7ce8846d31a5fac6c142dcce8b882"}, - {file = "contourpy-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27bc79200c742f9746d7dd51a734ee326a292d77e7d94c8af6e08d1e6c15d545"}, - {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:052cc634bf903c604ef1a00a5aa093c54f81a2612faedaa43295809ffdde885e"}, - {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9382a1c0bc46230fb881c36229bfa23d8c303b889b788b939365578d762b5c18"}, - {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5cec36c5090e75a9ac9dbd0ff4a8cf7cecd60f1b6dc23a374c7d980a1cd710e"}, - {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f0cbd657e9bde94cd0e33aa7df94fb73c1ab7799378d3b3f902eb8eb2e04a3a"}, - {file = "contourpy-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:181cbace49874f4358e2929aaf7ba84006acb76694102e88dd15af861996c16e"}, - {file = "contourpy-1.1.0-cp311-cp311-win32.whl", hash = "sha256:edb989d31065b1acef3828a3688f88b2abb799a7db891c9e282df5ec7e46221b"}, - {file = "contourpy-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fb3b7d9e6243bfa1efb93ccfe64ec610d85cfe5aec2c25f97fbbd2e58b531256"}, - {file = "contourpy-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcb41692aa09aeb19c7c213411854402f29f6613845ad2453d30bf421fe68fed"}, - {file = "contourpy-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5d123a5bc63cd34c27ff9c7ac1cd978909e9c71da12e05be0231c608048bb2ae"}, - {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62013a2cf68abc80dadfd2307299bfa8f5aa0dcaec5b2954caeb5fa094171103"}, - {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0b6616375d7de55797d7a66ee7d087efe27f03d336c27cf1f32c02b8c1a5ac70"}, - {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:317267d915490d1e84577924bd61ba71bf8681a30e0d6c545f577363157e5e94"}, - {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d551f3a442655f3dcc1285723f9acd646ca5858834efeab4598d706206b09c9f"}, - {file = "contourpy-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e7a117ce7df5a938fe035cad481b0189049e8d92433b4b33aa7fc609344aafa1"}, - {file = "contourpy-1.1.0-cp38-cp38-win32.whl", hash = "sha256:108dfb5b3e731046a96c60bdc46a1a0ebee0760418951abecbe0fc07b5b93b27"}, - {file = "contourpy-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4f26b25b4f86087e7d75e63212756c38546e70f2a92d2be44f80114826e1cd4"}, - {file = "contourpy-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc00bb4225d57bff7ebb634646c0ee2a1298402ec10a5fe7af79df9a51c1bfd9"}, - {file = "contourpy-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:189ceb1525eb0655ab8487a9a9c41f42a73ba52d6789754788d1883fb06b2d8a"}, - {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f2931ed4741f98f74b410b16e5213f71dcccee67518970c42f64153ea9313b9"}, - {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:30f511c05fab7f12e0b1b7730ebdc2ec8deedcfb505bc27eb570ff47c51a8f15"}, - {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:143dde50520a9f90e4a2703f367cf8ec96a73042b72e68fcd184e1279962eb6f"}, - {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e94bef2580e25b5fdb183bf98a2faa2adc5b638736b2c0a4da98691da641316a"}, - {file = "contourpy-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ed614aea8462735e7d70141374bd7650afd1c3f3cb0c2dbbcbe44e14331bf002"}, - {file = "contourpy-1.1.0-cp39-cp39-win32.whl", hash = "sha256:71551f9520f008b2950bef5f16b0e3587506ef4f23c734b71ffb7b89f8721999"}, - {file = "contourpy-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:438ba416d02f82b692e371858143970ed2eb6337d9cdbbede0d8ad9f3d7dd17d"}, - {file = "contourpy-1.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a698c6a7a432789e587168573a864a7ea374c6be8d4f31f9d87c001d5a843493"}, - {file = "contourpy-1.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397b0ac8a12880412da3551a8cb5a187d3298a72802b45a3bd1805e204ad8439"}, - {file = "contourpy-1.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:a67259c2b493b00e5a4d0f7bfae51fb4b3371395e47d079a4446e9b0f4d70e76"}, - {file = "contourpy-1.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2b836d22bd2c7bb2700348e4521b25e077255ebb6ab68e351ab5aa91ca27e027"}, - {file = "contourpy-1.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084eaa568400cfaf7179b847ac871582199b1b44d5699198e9602ecbbb5f6104"}, - {file = "contourpy-1.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:911ff4fd53e26b019f898f32db0d4956c9d227d51338fb3b03ec72ff0084ee5f"}, - {file = "contourpy-1.1.0.tar.gz", hash = "sha256:e53046c3863828d21d531cc3b53786e6580eb1ba02477e8681009b6aa0870b21"}, -] - -[package.dependencies] -numpy = ">=1.16" - -[package.extras] -bokeh = ["bokeh", "selenium"] -docs = ["furo", "sphinx-copybutton"] -mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.2.0)", "types-Pillow"] -test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] -test-no-images = ["pytest", "pytest-cov", "wurlitzer"] - [[package]] name = "contourpy" version = "1.1.1" @@ -167,7 +105,10 @@ files = [ ] [package.dependencies] -numpy = {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""} +numpy = [ + {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""}, + {version = ">=1.26.0rc1,<2.0", markers = "python_version >= \"3.12\""}, +] [package.extras] bokeh = ["bokeh", "selenium"] @@ -1078,6 +1019,51 @@ files = [ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + [[package]] name = "nvidia-cublas-cu12" version = "12.1.3.1" @@ -1933,12 +1919,12 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -array = ["jaxtyping", "numpy", "torch"] -array-no-torch = ["jaxtyping", "numpy"] +array = ["jaxtyping", "numpy", "numpy", "torch"] +array-no-torch = ["jaxtyping", "numpy", "numpy"] notebook = ["ipython"] zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "ad931582fd4069a70774c8f616036c92cc008e0bacde72b07e2260f169fdb905" +content-hash = "ba5985c876532a581082f5dbc55e3ca2237cf82f08044181d55b19a001b00371" diff --git a/pyproject.toml b/pyproject.toml index 9533aed6..99af05b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,10 @@ repository = "https://github.com/mivanit/muutils" [tool.poetry.dependencies] python = "^3.8" -numpy = { version = "^1.22.4", optional = true } +numpy = [ + { version = "^1.24.4", optional = true, markers = "python_version < '3.9'" }, + { version = "^1.26.4", optional = true, markers = "python_version >= '3.9'" }, +] torch = { version = ">=1.13.1", optional = true } jaxtyping = { version = "^0.2.12", optional = true } ipython = { version = "^8.20.0", optional = true, python = "^3.10" } From 19128f1c16677ef89d8533a41a434cc4702d1fc8 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 17:39:32 -0700 Subject: [PATCH 134/158] updated setup-python to v5 --- .github/workflows/checks.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index ddf6b27d..2746e08a 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -34,7 +34,7 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' @@ -90,7 +90,7 @@ jobs: fetch-depth: 0 - name: Set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.versions.python }} From 96226dd458ce28a358ff881269485fa3e3f55d78 Mon Sep 17 00:00:00 2001 From: mivanit Date: Wed, 19 Jun 2024 17:47:03 -0700 Subject: [PATCH 135/158] comments & classifiers in pyproject.toml --- pyproject.toml | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 99af05b0..e293d6f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "muutils" version = "0.5.13" -description = "A collection of miscellaneous python utilities" +description = "miscellaneous python utilities" license = "GPL-3.0-only" authors = ["mivanit "] readme = "README.md" @@ -14,25 +14,33 @@ classifiers=[ "Development Status :: 4 - Beta", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", + "Topic :: Utilities", + "Typing :: Typed", ] repository = "https://github.com/mivanit/muutils" [tool.poetry.dependencies] python = "^3.8" +# [array] numpy = [ { version = "^1.24.4", optional = true, markers = "python_version < '3.9'" }, { version = "^1.26.4", optional = true, markers = "python_version >= '3.9'" }, ] torch = { version = ">=1.13.1", optional = true } jaxtyping = { version = "^0.2.12", optional = true } +# [notebook] ipython = { version = "^8.20.0", optional = true, python = "^3.10" } +# [zanj] zanj = { version = "^0.2.2", optional = true, python = "^3.10" } [tool.poetry.group.dev.dependencies] -pytest = "^8.2.2" +# typing mypy = "^1.0.1" +# tests & coverage +pytest = "^8.2.2" pytest-cov = "^4.1.0" coverage-badge = "^1.1.0" +# for testing plotting matplotlib = "^3.0.0" plotly = "^5.0.0" @@ -42,6 +50,7 @@ ruff = "^0.4.8" [tool.poetry.extras] array = ["numpy", "torch", "jaxtyping"] +# special group for CI, where we install cpu torch separately array_no_torch = ["numpy", "jaxtyping"] notebook = ["ipython"] zanj = ["zanj"] From 6be2e41b04de69304d7d69fee94095fcaa824272 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 15:06:56 -0700 Subject: [PATCH 136/158] wip --- .../json_serialize/serializable_dataclass.py | 222 ++++++++++++------ muutils/validate_type.py | 15 ++ .../unit/validate_type/test_validate_type.py | 7 + 3 files changed, 178 insertions(+), 66 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index a3786c43..032d2681 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -2,6 +2,7 @@ import abc import dataclasses +import functools import json import sys import types @@ -315,6 +316,112 @@ def zanj_register_loader_serializable_dataclass(cls: Type[T]): return lh +OnTypeAssertDo = typing.Literal["raise", "warn", "ignore"] + +DEFAULT_ON_TYPE_ASSERT: OnTypeAssertDo = "warn" + +def SerializableDataclass__validate_field_type( + self: SerializableDataclass, + field: SerializableField|str, + on_type_assert: OnTypeAssertDo = DEFAULT_ON_TYPE_ASSERT, +) -> bool: + # do nothing + if not field.assert_type: + return True + + if on_type_assert == "ignore": + return True + + # get field + if isinstance(field, str): + field = self.__dataclass_fields__[field] + + assert isinstance( + field, SerializableField + ), f"Field '{field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(field) = }" + + # get field type hints + field_type_hint: Any = get_cls_type_hints(self.__class__).get(field.name, None) + + # get the value + value: Any = getattr(self, field.name) + + # validate the type + if field_type_hint is not None: + try: + # validate the type + type_is_valid: bool = validate_type( + value, field_type_hint + ) + + # if not valid, raise or warn depending on the setting in the SerializableDataclass + if not type_is_valid: + msg: str = f"Field '{field.name}' on class {self.__class__.__name__} has type {type(value)}, but expected {field_type_hint}" + if on_type_assert == "raise": + raise ValueError(msg) + else: + warnings.warn(msg) + + except Exception as e: + raise ValueError( + "exception while validating type: " + + f"{field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }" + ) from e + else: + raise ValueError( + f"Cannot get type hints for {self.__class__.__name__}, field {field.name = } and so cannot validate." + + f"Python version is {sys.version_info = }. You can:\n" + + f" - disable `assert_type`. Currently: {field.assert_type = }\n" + + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {field.type = }\n" + + " - use python 3.9.x or higher\n" + + " - coming in a future release, specify custom type validation functions\n" + ) + + + +def SerializableDataclass__validate_fields_types(self: SerializableDataclass, on_type_assert: OnTypeAssertDo = DEFAULT_ON_TYPE_ASSERT) -> bool: + """validate the types of the fields on a SerializableDataclass""" + + # arg validation + if on_type_assert not in ("raise", "warn", "ignore"): + raise ValueError( + f"Invalid value for {on_type_assert = }, expected 'raise', 'warn', or 'ignore'" + ) + + # do nothing if ignore + if on_type_assert == "ignore": + return + + # if except, bundle the exceptions + results: dict[str, bool] = dict() + exceptions: dict[str, Exception] = dict() + + # for each field in the class + cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self) + for field in cls_fields: + try: + assert self.validate_field_type(field, on_type_assert) + except Exception as e: + exceptions[field.name] = e + + # figure out what to do with the exceptions + if len(exceptions) > 0: + if on_type_assert in ("warn", "ignore"): + msg: str = ( + f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}" + + f"\n\t" + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]) + ) + if on_type_assert == "warn": + warnings.warn(msg) + else: + raise ValueError(msg) from exceptions[0] + else: + assert on_type_assert == "ignore" + + return True + + + class SerializableDataclass(abc.ABC): """Base class for serializable dataclasses @@ -323,11 +430,17 @@ class SerializableDataclass(abc.ABC): """ def serialize(self) -> dict[str, Any]: - raise NotImplementedError + raise NotImplementedError(f"decorate {self.__class__ = } with `@serializable_dataclass`") @classmethod def load(cls: Type[T], data: dict[str, Any] | T) -> T: - raise NotImplementedError + raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`") + + def validate_fields_types(self, on_type_assert: OnTypeAssertDo = DEFAULT_ON_TYPE_ASSERT) -> bool: + return SerializableDataclass__validate_fields_types(self, on_type_assert) + + def validate_field_type(self, field: SerializableField|str, on_type_assert: OnTypeAssertDo = DEFAULT_ON_TYPE_ASSERT) -> bool: + return SerializableDataclass__validate_field_type(self, field, on_type_assert) def __eq__(self, other: Any) -> bool: return dc_eq(self, other) @@ -402,6 +515,35 @@ class CantGetTypeHintsWarning(UserWarning): pass +# cache this so we don't have to keep getting it +@functools.lru_cache(typed=True) +def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]: + "cached typing.get_type_hints for a class" + # get the type hints for the class + cls_type_hints: dict[str, Any] + try: + cls_type_hints = typing.get_type_hints(cls) + except TypeError as e: + if sys.version_info < (3, 9): + warnings.warn( + f"Cannot get type hints for {cls.__name__}. Python version is {sys.version_info = }. You can:\n" + + " - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x)\n" + + " - use python 3.9.x or higher\n" + + " - add explicit loading functions to the fields\n" + + f" {dataclasses.fields(cls) = }", + CantGetTypeHintsWarning, + ) + cls_type_hints = dict() + else: + raise TypeError( + f"Cannot get type hints for {cls.__name__}. Python version is {sys.version_info = }\n" + + f" {dataclasses.fields(cls) = }\n" + + f" {e = }" + ) from e + + return cls_type_hints + + # Step 3: Create a custom serializable_dataclass decorator def serializable_dataclass( # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it @@ -530,27 +672,7 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: data, typing.Mapping ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" - # get the type hints for the class - cls_type_hints: dict[str, Any] - try: - cls_type_hints = typing.get_type_hints(cls) - except TypeError as e: - if sys.version_info < (3, 9): - warnings.warn( - f"Cannot get type hints for {cls.__name__}. Python version is {sys.version_info = }. You can:\n" - + " - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x)\n" - + " - use python 3.9.x or higher\n" - + " - add explicit loading functions to the fields\n" - + f" {dataclasses.fields(cls) = }", - CantGetTypeHintsWarning, - ) - cls_type_hints = dict() - else: - raise TypeError( - f"Cannot get type hints for {cls.__name__}. Python version is {sys.version_info = }\n" - + f" {dataclasses.fields(cls) = }\n" - + f" {e = }" - ) from e + cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) # initialize dict for keeping what we will pass to the constructor ctor_kwargs: dict[str, Any] = dict() @@ -592,57 +714,25 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: # store the value in the constructor kwargs ctor_kwargs[field.name] = value - # validate the type - if field.assert_type and on_type_assert in ("raise", "warn"): - if field.name in ctor_kwargs: - if field_type_hint is not None: - try: - # validate the type - type_is_valid: bool = validate_type( - ctor_kwargs[field.name], field_type_hint - ) - - # if not valid, raise or warn depending on the setting in the SerializableDataclass - if not type_is_valid: - msg: str = f"Field '{field.name}' on class {cls.__name__} has type {type(ctor_kwargs[field.name])}, but expected {field_type_hint}" - if on_type_assert == "raise": - raise ValueError(msg) - else: - warnings.warn(msg) - - except Exception as e: - raise ValueError( - f"{field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {ctor_kwargs[field.name] = }" - ) from e - else: - raise ValueError( - f"Cannot get type hints for {cls.__name__}, field {field.name = } and so cannot validate." - + f"Python version is {sys.version_info = }. You can:\n" - + f" - disable `assert_type`. Currently: {field.assert_type = }\n" - + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {field.type = }\n" - + " - use python 3.9.x or higher\n" - + " - coming in a future release, specify custom type validation functions\n" - ) - else: - # TODO: raise an exception here? Can't validate if data given - warnings.warn( - f"Field '{field.name}' on class {cls} has no type hint, but {field.assert_type = }\n{field = }\n{cls_type_hints = }\n{data = }", - CantGetTypeHintsWarning, - ) - else: - if on_type_assert != "ignore": - raise ValueError( - f"Invalid value for {on_type_assert = }, expected 'raise', 'warn', or 'ignore'" - ) + # create a new instance of the class with the constructor kwargs + output: cls = cls(**ctor_kwargs) + + # validate the types of the fields if needed + if on_type_assert in ("raise", "warn"): + output.validate_fields_types() - return cls(**ctor_kwargs) + # return the new instance + return output # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments # type is `Callable[[T], dict]` cls.serialize = serialize # type: ignore[attr-defined] # type is `Callable[[dict], T]` cls.load = load # type: ignore[attr-defined] + # type is `Callable[[T, OnTypeAssertDo], bool]` + cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined] + # type is `Callable[[T, T], bool]` cls.__eq__ = lambda self, other: dc_eq(self, other) # type: ignore[assignment] # Register the class with ZANJ diff --git a/muutils/validate_type.py b/muutils/validate_type.py index e8b3de23..11fb6fcd 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -119,6 +119,21 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # check all items in tuple are of the correct type return all(validate_type(item, arg) for item, arg in zip(value, args)) + if origin is type: + # no args + if len(args) == 0: + return isinstance(value, type) + # incorrect number of args + if len(args) != 1: + raise TypeError( + f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }", + f"{GenericAliasTypes = }", + ) + # check is type + if origin is type: + return True + + # TODO: Callables, etc. raise ValueError( diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index 8784b788..769f85ca 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -395,6 +395,13 @@ def test_validate_type_complex(): ) +def test_validate_type_class(): + class Test: + def __init__(self, a: int, b: str): + self.a: int = a + self.b: str = b + + @pytest.mark.parametrize( "value, expected_type, expected_result", [ From 228e1817a8b16d8ee6f1b0e7685009470afafc70 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 15:11:36 -0700 Subject: [PATCH 137/158] reorg --- .../json_serialize/serializable_dataclass.py | 258 +----------------- muutils/json_serialize/serializable_field.py | 139 ++++++++++ muutils/json_serialize/util.py | 132 ++++++++- 3 files changed, 273 insertions(+), 256 deletions(-) create mode 100644 muutils/json_serialize/serializable_field.py diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 032d2681..c267bd54 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -11,263 +11,11 @@ from typing import Any, Callable, Optional, Type, TypeVar, Union from muutils.validate_type import validate_type +from muutils.json_serialize.serializable_field import SerializableField, serializable_field +from muutils.json_serialize.util import array_safe_eq, dc_eq # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access - -class SerializableField(dataclasses.Field): - """extension of `dataclasses.Field` with additional serialization properties""" - - __slots__ = ( - # from dataclasses.Field.__slots__ - "name", - "type", - "default", - "default_factory", - "repr", - "hash", - "init", - "compare", - "metadata", - "kw_only", - "_field_type", # Private: not to be used by user code. - # new ones - "serialize", - "serialization_fn", - "loading_fn", - "assert_type", - ) - - def __init__( - self, - default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, - default_factory: Union[ - Callable[[], Any], dataclasses._MISSING_TYPE - ] = dataclasses.MISSING, - init: bool = True, - repr: bool = True, - hash: Optional[bool] = None, - compare: bool = True, - # TODO: add field for custom comparator (such as serializing) - metadata: Optional[types.MappingProxyType] = None, - kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, - serialize: bool = True, - serialization_fn: Optional[Callable[[Any], Any]] = None, - loading_fn: Optional[Callable[[Any], Any]] = None, - assert_type: bool = True, - # TODO: add field for custom type assertion - ): - # TODO: should we do this check, or assume the user knows what they are doing? - if init and not serialize: - raise ValueError("Cannot have init=True and serialize=False") - - # need to assemble kwargs in this hacky way so as not to upset type checking - super_kwargs: dict[str, Any] = dict( - default=default, - default_factory=default_factory, - init=init, - repr=repr, - hash=hash, - compare=compare, - kw_only=kw_only, - ) - - if metadata is not None: - super_kwargs["metadata"] = metadata - else: - super_kwargs["metadata"] = types.MappingProxyType({}) - - # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy - if sys.version_info < (3, 10): - if super_kwargs["kw_only"] == True: # noqa: E712 - raise ValueError("kw_only is not supported in python >=3.9") - else: - del super_kwargs["kw_only"] - - # actually init the super class - super().__init__(**super_kwargs) # type: ignore[call-arg] - - # now init the new fields - self.serialize: bool = serialize - self.serialization_fn: Optional[Callable[[Any], Any]] = serialization_fn - self.loading_fn: Optional[Callable[[Any], Any]] = loading_fn - self.assert_type: bool = assert_type - - @classmethod - def from_Field(cls, field: dataclasses.Field) -> "SerializableField": - """copy all values from a `dataclasses.Field` to new `SerializableField`""" - return cls( - default=field.default, - default_factory=field.default_factory, - init=field.init, - repr=field.repr, - hash=field.hash, - compare=field.compare, - metadata=field.metadata, - kw_only=getattr(field, "kw_only", dataclasses.MISSING), # for python <3.9 - serialize=field.repr, - serialization_fn=None, - loading_fn=None, - ) - - -# Step 2: Create a serializable_field function -# no type hint to avoid confusing mypy -def serializable_field(*args, **kwargs): # -> SerializableField: - """Create a new SerializableField - - note that if not using ZANJ, and you have a class inside a container, you MUST provide - `serialization_fn` and `loading_fn` to serialize and load the container. - ZANJ will automatically do this for you. - - ``` - default: Any | dataclasses._MISSING_TYPE = dataclasses.MISSING, - default_factory: Callable[[], Any] - | dataclasses._MISSING_TYPE = dataclasses.MISSING, - init: bool = True, - repr: bool = True, - hash: Optional[bool] = None, - compare: bool = True, - metadata: types.MappingProxyType | None = None, - kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING, - serialize: bool = True, - serialization_fn: Optional[Callable[[Any], Any]] = None, - loading_fn: Optional[Callable[[Any], Any]] = None, - assert_type: bool = True, - ``` - """ - return SerializableField(*args, **kwargs) - - -# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises -def array_safe_eq(a: Any, b: Any) -> bool: - """check if two objects are equal, account for if numpy arrays or torch tensors""" - if a is b: - return True - - if ( - str(type(a)) == "" - and str(type(b)) == "" - ) or ( - str(type(a)) == "" - and str(type(b)) == "" - ): - return (a == b).all() - - if ( - str(type(a)) == "" - and str(type(b)) == "" - ): - return a.equals(b) - - if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence): - return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b)) - - if isinstance(a, (dict, typing.Mapping)) and isinstance(b, (dict, typing.Mapping)): - return len(a) == len(b) and all( - array_safe_eq(k1, k2) and array_safe_eq(a[k1], b[k2]) - for k1, k2 in zip(a.keys(), b.keys()) - ) - - try: - return bool(a == b) - except (TypeError, ValueError) as e: - warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}") - return NotImplemented # type: ignore[return-value] - - -def dc_eq( - dc1, - dc2, - except_when_class_mismatch: bool = False, - false_when_class_mismatch: bool = True, - except_when_field_mismatch: bool = False, -) -> bool: - """checks if two dataclasses which (might) hold numpy arrays are equal - - # Parameters: - - `dc1`: the first dataclass - - `dc2`: the second dataclass - - `except_when_class_mismatch: bool` - if `True`, will throw `TypeError` if the classes are different. - if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False` - (default: `False`) - - `false_when_class_mismatch: bool` - only relevant if `except_when_class_mismatch` is `False`. - if `True`, will return `False` if the classes are different. - if `False`, will attempt to compare the fields. - - `except_when_field_mismatch: bool` - only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`. - if `True`, will throw `TypeError` if the fields are different. - (default: `True`) - - # Returns: - - `bool`: True if the dataclasses are equal, False otherwise - - # Raises: - - `TypeError`: if the dataclasses are of different classes - - `AttributeError`: if the dataclasses have different fields - - ``` - [START] - ā–¼ - ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” - ā”‚dc1 is dc2?ā”œā”€ā–ŗā”‚ classes ā”‚ - ā””ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜Noā”‚ match? ā”‚ - ā”€ā”€ā”€ā”€ ā”‚ ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¤ - (True)ā—„ā”€ā”€ā”˜Yes ā”‚No ā”‚Yes - ā”€ā”€ā”€ā”€ ā–¼ ā–¼ - ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” - ā”‚ except when ā”‚ ā”‚ fields keysā”‚ - ā”‚ class mismatch?ā”‚ ā”‚ match? ā”‚ - ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”˜ ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”˜ - ā”‚Yes ā”‚No ā”‚No ā”‚Yes - ā–¼ ā–¼ ā–¼ ā–¼ - ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” - { raise } ā”‚ except ā”‚ ā”‚ field ā”‚ - { TypeError } ā”‚ when ā”‚ ā”‚ values ā”‚ - ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ ā”‚ field ā”‚ ā”‚ match? ā”‚ - ā”‚ mismatch?ā”‚ ā”œā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”˜ - ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”˜ ā”‚ ā”‚Yes - ā”‚Yes ā”‚No ā”‚No ā–¼ - ā–¼ ā–¼ ā”‚ ā”€ā”€ā”€ā”€ - ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ ā”€ā”€ā”€ā”€ā”€ ā”‚ (True) - { raise } (False)ā—„ā”˜ ā”€ā”€ā”€ā”€ - { AttributeError} ā”€ā”€ā”€ā”€ā”€ - ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ - ``` - - """ - if dc1 is dc2: - return True - - if dc1.__class__ is not dc2.__class__: - if except_when_class_mismatch: - # if the classes don't match, raise an error - raise TypeError( - f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`" - ) - else: - dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)]) - dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)]) - fields_match: bool = set(dc1_fields) == set(dc2_fields) - - if not fields_match: - # if the fields match, keep going - if except_when_field_mismatch: - raise AttributeError( - f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`" - ) - else: - return False - - return all( - array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)) - for fld in dataclasses.fields(dc1) - if fld.compare - ) - - T = TypeVar("T") @@ -278,7 +26,7 @@ class ZanjMissingWarning(UserWarning): _zanj_loading_needs_import: bool = True -def zanj_register_loader_serializable_dataclass(cls: Type[T]): +def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]): """Register a serializable dataclass with the ZANJ backport diff --git a/muutils/json_serialize/serializable_field.py b/muutils/json_serialize/serializable_field.py new file mode 100644 index 00000000..86d0a1f1 --- /dev/null +++ b/muutils/json_serialize/serializable_field.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import abc +import dataclasses +import functools +import json +import sys +import types +import typing +import warnings +from typing import Any, Callable, Optional, Type, TypeVar, Union + +from muutils.validate_type import validate_type + +# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access + + +class SerializableField(dataclasses.Field): + """extension of `dataclasses.Field` with additional serialization properties""" + + __slots__ = ( + # from dataclasses.Field.__slots__ + "name", + "type", + "default", + "default_factory", + "repr", + "hash", + "init", + "compare", + "metadata", + "kw_only", + "_field_type", # Private: not to be used by user code. + # new ones + "serialize", + "serialization_fn", + "loading_fn", + "assert_type", + ) + + def __init__( + self, + default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING, + default_factory: Union[ + Callable[[], Any], dataclasses._MISSING_TYPE + ] = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + # TODO: add field for custom comparator (such as serializing) + metadata: Optional[types.MappingProxyType] = None, + kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING, + serialize: bool = True, + serialization_fn: Optional[Callable[[Any], Any]] = None, + loading_fn: Optional[Callable[[Any], Any]] = None, + assert_type: bool = True, + # TODO: add field for custom type assertion + ): + # TODO: should we do this check, or assume the user knows what they are doing? + if init and not serialize: + raise ValueError("Cannot have init=True and serialize=False") + + # need to assemble kwargs in this hacky way so as not to upset type checking + super_kwargs: dict[str, Any] = dict( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + kw_only=kw_only, + ) + + if metadata is not None: + super_kwargs["metadata"] = metadata + else: + super_kwargs["metadata"] = types.MappingProxyType({}) + + # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy + if sys.version_info < (3, 10): + if super_kwargs["kw_only"] == True: # noqa: E712 + raise ValueError("kw_only is not supported in python >=3.9") + else: + del super_kwargs["kw_only"] + + # actually init the super class + super().__init__(**super_kwargs) # type: ignore[call-arg] + + # now init the new fields + self.serialize: bool = serialize + self.serialization_fn: Optional[Callable[[Any], Any]] = serialization_fn + self.loading_fn: Optional[Callable[[Any], Any]] = loading_fn + self.assert_type: bool = assert_type + + @classmethod + def from_Field(cls, field: dataclasses.Field) -> "SerializableField": + """copy all values from a `dataclasses.Field` to new `SerializableField`""" + return cls( + default=field.default, + default_factory=field.default_factory, + init=field.init, + repr=field.repr, + hash=field.hash, + compare=field.compare, + metadata=field.metadata, + kw_only=getattr(field, "kw_only", dataclasses.MISSING), # for python <3.9 + serialize=field.repr, + serialization_fn=None, + loading_fn=None, + ) + + +# Step 2: Create a serializable_field function +# no type hint to avoid confusing mypy +def serializable_field(*args, **kwargs): # -> SerializableField: + """Create a new SerializableField + + note that if not using ZANJ, and you have a class inside a container, you MUST provide + `serialization_fn` and `loading_fn` to serialize and load the container. + ZANJ will automatically do this for you. + + ``` + default: Any | dataclasses._MISSING_TYPE = dataclasses.MISSING, + default_factory: Callable[[], Any] + | dataclasses._MISSING_TYPE = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: types.MappingProxyType | None = None, + kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING, + serialize: bool = True, + serialization_fn: Optional[Callable[[Any], Any]] = None, + loading_fn: Optional[Callable[[Any], Any]] = None, + assert_type: bool = True, + ``` + """ + return SerializableField(*args, **kwargs) \ No newline at end of file diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index e47ba597..632ae3a5 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -1,11 +1,12 @@ from __future__ import annotations +import dataclasses import functools import inspect import sys import typing import warnings -from typing import Any, Callable, Iterable, Literal, Union +from typing import Any, Callable, Iterable, Literal, TypeVar, Union _NUMPY_WORKING: bool try: @@ -127,3 +128,132 @@ def safe_getsource(func) -> list[str]: return string_as_lines(inspect.getsource(func)) except Exception as e: return string_as_lines(f"Error: Unable to retrieve source code:\n{e}") + + +# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises +def array_safe_eq(a: Any, b: Any) -> bool: + """check if two objects are equal, account for if numpy arrays or torch tensors""" + if a is b: + return True + + if ( + str(type(a)) == "" + and str(type(b)) == "" + ) or ( + str(type(a)) == "" + and str(type(b)) == "" + ): + return (a == b).all() + + if ( + str(type(a)) == "" + and str(type(b)) == "" + ): + return a.equals(b) + + if isinstance(a, typing.Sequence) and isinstance(b, typing.Sequence): + return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b)) + + if isinstance(a, (dict, typing.Mapping)) and isinstance(b, (dict, typing.Mapping)): + return len(a) == len(b) and all( + array_safe_eq(k1, k2) and array_safe_eq(a[k1], b[k2]) + for k1, k2 in zip(a.keys(), b.keys()) + ) + + try: + return bool(a == b) + except (TypeError, ValueError) as e: + warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}") + return NotImplemented # type: ignore[return-value] + + +def dc_eq( + dc1, + dc2, + except_when_class_mismatch: bool = False, + false_when_class_mismatch: bool = True, + except_when_field_mismatch: bool = False, +) -> bool: + """checks if two dataclasses which (might) hold numpy arrays are equal + + # Parameters: + - `dc1`: the first dataclass + - `dc2`: the second dataclass + - `except_when_class_mismatch: bool` + if `True`, will throw `TypeError` if the classes are different. + if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False` + (default: `False`) + - `false_when_class_mismatch: bool` + only relevant if `except_when_class_mismatch` is `False`. + if `True`, will return `False` if the classes are different. + if `False`, will attempt to compare the fields. + - `except_when_field_mismatch: bool` + only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`. + if `True`, will throw `TypeError` if the fields are different. + (default: `True`) + + # Returns: + - `bool`: True if the dataclasses are equal, False otherwise + + # Raises: + - `TypeError`: if the dataclasses are of different classes + - `AttributeError`: if the dataclasses have different fields + + ``` + [START] + ā–¼ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + ā”‚dc1 is dc2?ā”œā”€ā–ŗā”‚ classes ā”‚ + ā””ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜Noā”‚ match? ā”‚ + ā”€ā”€ā”€ā”€ ā”‚ ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¤ + (True)ā—„ā”€ā”€ā”˜Yes ā”‚No ā”‚Yes + ā”€ā”€ā”€ā”€ ā–¼ ā–¼ + ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + ā”‚ except when ā”‚ ā”‚ fields keysā”‚ + ā”‚ class mismatch?ā”‚ ā”‚ match? ā”‚ + ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”˜ ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”€ā”˜ + ā”‚Yes ā”‚No ā”‚No ā”‚Yes + ā–¼ ā–¼ ā–¼ ā–¼ + ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” + { raise } ā”‚ except ā”‚ ā”‚ field ā”‚ + { TypeError } ā”‚ when ā”‚ ā”‚ values ā”‚ + ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ ā”‚ field ā”‚ ā”‚ match? ā”‚ + ā”‚ mismatch?ā”‚ ā”œā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”€ā”˜ + ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¬ā”€ā”€ā”˜ ā”‚ ā”‚Yes + ā”‚Yes ā”‚No ā”‚No ā–¼ + ā–¼ ā–¼ ā”‚ ā”€ā”€ā”€ā”€ + ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ ā”€ā”€ā”€ā”€ā”€ ā”‚ (True) + { raise } (False)ā—„ā”˜ ā”€ā”€ā”€ā”€ + { AttributeError} ā”€ā”€ā”€ā”€ā”€ + ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ + ``` + + """ + if dc1 is dc2: + return True + + if dc1.__class__ is not dc2.__class__: + if except_when_class_mismatch: + # if the classes don't match, raise an error + raise TypeError( + f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`" + ) + else: + dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)]) + dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)]) + fields_match: bool = set(dc1_fields) == set(dc2_fields) + + if not fields_match: + # if the fields match, keep going + if except_when_field_mismatch: + raise AttributeError( + f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`" + ) + else: + return False + + return all( + array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)) + for fld in dataclasses.fields(dc1) + if fld.compare + ) From 1596a82e96f33212b4a71bc23374980430a3a40c Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 15:40:18 -0700 Subject: [PATCH 138/158] error mode module --- muutils/errormode.py | 27 ++++++++ tests/unit/test_errormode.py | 115 +++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 muutils/errormode.py create mode 100644 tests/unit/test_errormode.py diff --git a/muutils/errormode.py b/muutils/errormode.py new file mode 100644 index 00000000..51f6f691 --- /dev/null +++ b/muutils/errormode.py @@ -0,0 +1,27 @@ +import typing +import warnings +from enum import Enum + +class ErrorMode(Enum): + EXCEPT = "except" + WARN = "warn" + IGNORE = "ignore" + + def process( + self, + msg: str, + except_cls: typing.Type[Exception] = ValueError, + warn_cls: typing.Type[Warning] = UserWarning, + except_from: typing.Optional[typing.Type[Exception]] = None, + ): + if self is ErrorMode.EXCEPT: + if except_from is not None: + raise except_cls(msg) from except_from + else: + raise except_cls(msg) + elif self is ErrorMode.WARN: + warnings.warn(msg, warn_cls) + elif self is ErrorMode.IGNORE: + pass + else: + raise ValueError(f"Unknown error mode {self}") \ No newline at end of file diff --git a/tests/unit/test_errormode.py b/tests/unit/test_errormode.py new file mode 100644 index 00000000..cd5f28c9 --- /dev/null +++ b/tests/unit/test_errormode.py @@ -0,0 +1,115 @@ +import typing +import warnings + +from muutils.errormode import ErrorMode + +import pytest + +""" +import typing +import warnings +from enum import Enum + +class ErrorMode(Enum): + EXCEPT = "except" + WARN = "warn" + IGNORE = "ignore" + + def process( + self, + msg: str, + except_cls: typing.Type[Exception] = ValueError, + warn_cls: typing.Type[Warning] = UserWarning, + except_from: typing.Optional[typing.Type[Exception]] = None, + ): + if self is ErrorMode.EXCEPT: + if except_from is not None: + raise except_from(msg) from except_from + else: + raise except_cls(msg) + elif self is ErrorMode.WARN: + warnings.warn(msg, warn_cls) + elif self is ErrorMode.IGNORE: + pass + else: + raise ValueError(f"Unknown error mode {self}") +""" + + + + +def test_except(): + with pytest.raises(ValueError): + ErrorMode.EXCEPT.process("test-except", except_cls=ValueError) + + with pytest.raises(TypeError): + ErrorMode.EXCEPT.process("test-except", except_cls=TypeError) + + with pytest.raises(RuntimeError): + ErrorMode.EXCEPT.process("test-except", except_cls=RuntimeError) + + with pytest.raises(KeyError): + ErrorMode.EXCEPT.process("test-except", except_cls=KeyError) + + with pytest.raises(KeyError): + ErrorMode.EXCEPT.process("test-except", except_cls=KeyError, except_from=ValueError("base exception")) + + +def test_warn(): + with pytest.warns(UserWarning): + ErrorMode.WARN.process("test-warn", warn_cls=UserWarning) + + with pytest.warns(Warning): + ErrorMode.WARN.process("test-warn", warn_cls=Warning) + + with pytest.warns(DeprecationWarning): + ErrorMode.WARN.process("test-warn", warn_cls=DeprecationWarning) + +def test_ignore(): + with warnings.catch_warnings(record=True) as w: + ErrorMode.IGNORE.process("test-ignore") + + ErrorMode.IGNORE.process("test-ignore", except_cls=ValueError) + ErrorMode.IGNORE.process("test-ignore", except_from=TypeError) + + ErrorMode.IGNORE.process("test-ignore", warn_cls=UserWarning) + + assert len(w) == 0, f"There should be no warnings: {w}" + +def test_except_custom(): + class MyCustomError(ValueError): + pass + + with pytest.raises(MyCustomError): + ErrorMode.EXCEPT.process("test-except", except_cls=MyCustomError) + +def test_warn_custom(): + class MyCustomWarning(Warning): + pass + + with pytest.warns(MyCustomWarning): + ErrorMode.WARN.process("test-warn", warn_cls=MyCustomWarning) + + +def test_invalid_mode(): + with pytest.raises(ValueError): + ErrorMode("invalid") + + +def test_except_mode_chained_exception(): + try: + # set up the base exception + try: + raise KeyError("base exception") + except Exception as base_exception: + # catch it, raise another exception with it as the cause + ErrorMode.EXCEPT.process("Test chained exception", except_cls=RuntimeError, except_from=base_exception) + # catch the outer exception + except RuntimeError as e: + assert str(e) == "Test chained exception" + # check that the cause is the base exception + assert isinstance(e.__cause__, KeyError) + assert repr(e.__cause__) == "KeyError('base exception')" + else: + assert False, "Expected RuntimeError with cause KeyError" + From 2ee10331ea110313f3a3d740356c254b452998bd Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 16:04:45 -0700 Subject: [PATCH 139/158] ErrorMode.from_any, more tests --- muutils/errormode.py | 41 +++++++++++++++++- tests/unit/test_errormode.py | 81 +++++++++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 2 deletions(-) diff --git a/muutils/errormode.py b/muutils/errormode.py index 51f6f691..93ef8246 100644 --- a/muutils/errormode.py +++ b/muutils/errormode.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing import warnings from enum import Enum @@ -24,4 +26,41 @@ def process( elif self is ErrorMode.IGNORE: pass else: - raise ValueError(f"Unknown error mode {self}") \ No newline at end of file + raise ValueError(f"Unknown error mode {self}") + + @staticmethod + def from_any(mode: "str|ErrorMode", allow_aliases: bool = True) -> ErrorMode: + if isinstance(mode, ErrorMode): + return mode + elif isinstance(mode, str): + mode = mode.strip().lower() + if not allow_aliases: + try: + return ErrorMode(mode) + except ValueError as e: + raise KeyError(f"Unknown error mode {mode}") from e + else: + return ERROR_MODE_ALIASES[mode] + else: + raise TypeError(f"Expected {ErrorMode} or str, got {type(mode) = }") + + +ERROR_MODE_ALIASES: dict[str, ErrorMode] = { + # base + "except": ErrorMode.EXCEPT, + "warn": ErrorMode.WARN, + "ignore": ErrorMode.IGNORE, + # except + "e": ErrorMode.EXCEPT, + "error": ErrorMode.EXCEPT, + "err": ErrorMode.EXCEPT, + "raise": ErrorMode.EXCEPT, + # warn + "w": ErrorMode.WARN, + "warning": ErrorMode.WARN, + # ignore + "i": ErrorMode.IGNORE, + "silent": ErrorMode.IGNORE, + "quiet": ErrorMode.IGNORE, + "nothing": ErrorMode.IGNORE, +} \ No newline at end of file diff --git a/tests/unit/test_errormode.py b/tests/unit/test_errormode.py index cd5f28c9..0f9d965e 100644 --- a/tests/unit/test_errormode.py +++ b/tests/unit/test_errormode.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import typing import warnings -from muutils.errormode import ErrorMode +from muutils.errormode import ErrorMode, ERROR_MODE_ALIASES import pytest @@ -113,3 +115,80 @@ def test_except_mode_chained_exception(): else: assert False, "Expected RuntimeError with cause KeyError" + + + +@pytest.mark.parametrize("mode, expected_mode", [ + ("except", ErrorMode.EXCEPT), + ("warn", ErrorMode.WARN), + ("ignore", ErrorMode.IGNORE), + ("Except", ErrorMode.EXCEPT), + ("Warn", ErrorMode.WARN), + ("Ignore", ErrorMode.IGNORE), + (" \teXcEpT \n", ErrorMode.EXCEPT), + ("WaRn \t", ErrorMode.WARN), + (" \tIGNORE", ErrorMode.IGNORE), +]) +def test_from_any_strict_ok(mode, expected_mode): + assert ErrorMode.from_any(mode, allow_aliases=False) == expected_mode + +@pytest.mark.parametrize("mode, excepted_error", [ + (42, TypeError), + (42.0, TypeError), + (None, TypeError), + (object(), TypeError), + (True, TypeError), + (False, TypeError), + (["except"], TypeError), + ("invalid", KeyError), + (" \tinvalid", KeyError), + ("e", KeyError), + (" E", KeyError), + ("w", KeyError), + ("W", KeyError), + ("i", KeyError), + ("I", KeyError), + ("silent", KeyError), + ("Silent", KeyError), + ("quiet", KeyError), + ("Quiet", KeyError), + ("raise", KeyError), + ("Raise", KeyError), + ("error", KeyError), + ("Error", KeyError), + ("err", KeyError), + ("ErR\t", KeyError), + ("warning", KeyError), + ("Warning", KeyError), +]) +def test_from_any_strict_error(mode, excepted_error): + with pytest.raises(excepted_error): + ErrorMode.from_any(mode, allow_aliases=False) + + +@pytest.mark.parametrize("mode, expected_mode", [ + *list(ERROR_MODE_ALIASES.items()), + *list((a.upper(), b) for a, b in ERROR_MODE_ALIASES.items()), + *list((a.title(), b) for a, b in ERROR_MODE_ALIASES.items()), + *list((a.capitalize(), b) for a, b in ERROR_MODE_ALIASES.items()), + *list((f" \t{a} \t", b) for a, b in ERROR_MODE_ALIASES.items()), +]) +def test_from_any_aliases_ok(mode, expected_mode): + assert ErrorMode.from_any(mode) == expected_mode + assert ErrorMode.from_any(mode, allow_aliases=True) == expected_mode + + +@pytest.mark.parametrize("mode, excepted_error", [ + (42, TypeError), + (42.0, TypeError), + (None, TypeError), + (object(), TypeError), + (True, TypeError), + (False, TypeError), + (["except"], TypeError), + ("invalid", KeyError), + (" \tinvalid", KeyError), +]) +def test_from_any_aliases_error(mode, excepted_error): + with pytest.raises(excepted_error): + ErrorMode.from_any(mode, allow_aliases=True) \ No newline at end of file From 450900b1bcfd0628e7959c8f2643bba44d8bddc6 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 16:32:35 -0700 Subject: [PATCH 140/158] switched from_any to classmethod --- muutils/errormode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/muutils/errormode.py b/muutils/errormode.py index 93ef8246..4d64b402 100644 --- a/muutils/errormode.py +++ b/muutils/errormode.py @@ -28,8 +28,8 @@ def process( else: raise ValueError(f"Unknown error mode {self}") - @staticmethod - def from_any(mode: "str|ErrorMode", allow_aliases: bool = True) -> ErrorMode: + @classmethod + def from_any(cls, mode: "str|ErrorMode", allow_aliases: bool = True) -> ErrorMode: if isinstance(mode, ErrorMode): return mode elif isinstance(mode, str): From b0b98d75646a54b7f20fd09e24fb438d26c8edaf Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 16:32:42 -0700 Subject: [PATCH 141/158] wip --- muutils/dictmagic.py | 22 +++--- .../json_serialize/serializable_dataclass.py | 71 ++++++++----------- muutils/json_serialize/util.py | 1 - muutils/tensor_utils.py | 16 ++--- muutils/validate_type.py | 6 +- 5 files changed, 49 insertions(+), 67 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index c34ac34d..f30cf2c0 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -18,6 +18,7 @@ _KT = TypeVar("_KT") _VT = TypeVar("_VT") +from muutils.errormode import ErrorMode class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): """like a defaultdict, but default_factory is passed the key as an argument""" @@ -144,14 +145,14 @@ def kwargs_to_nested_dict( kwargs_dict: dict[str, Any], sep: str = ".", strip_prefix: Optional[str] = None, - when_unknown_prefix: typing.Literal["raise", "warn", "ignore"] = "warn", + when_unknown_prefix: ErrorMode = ErrorMode.WARN, transform_key: Optional[Callable[[str], str]] = None, ) -> dict[str, Any]: """given kwargs from fire, convert them to a nested dict if strip_prefix is not None, then all keys must start with the prefix. by default, will warn if an unknown prefix is found, but can be set to raise an error or ignore it: - `when_unknown_prefix: typing.Literal["raise", "warn", "ignore"]` + `when_unknown_prefix: ErrorMode` Example: ```python @@ -172,25 +173,20 @@ def main(**kwargs): the separator to use for nested keys - `strip_prefix: Optional[str] = None` if not None, then all keys must start with this prefix - - `when_unknown_prefix: typing.Literal["raise", "warn", "ignore"] = "warn"` + - `when_unknown_prefix: ErrorMode = ErrorMode.WARN` what to do when an unknown prefix is found - `transform_key: Callable[[str], str] | None = None` a function to apply to each key before adding it to the dict (applied after stripping the prefix) """ + when_unknown_prefix: ErrorMode = ErrorMode.from_any(when_unknown_prefix) filtered_kwargs: dict[str, Any] = dict() for key, value in kwargs_dict.items(): if strip_prefix is not None: if not key.startswith(strip_prefix): - if when_unknown_prefix == "raise": - raise ValueError(f"key {key} does not start with {strip_prefix}") - elif when_unknown_prefix == "warn": - warnings.warn(f"key {key} does not start with {strip_prefix}") - elif when_unknown_prefix == "ignore": - pass - else: - raise ValueError( - f"when_unknown_prefix must be one of 'raise', 'warn', or 'ignore', got {when_unknown_prefix}" - ) + when_unknown_prefix.process( + f"key '{key}' does not start with '{strip_prefix}'", + except_cls=ValueError, + ) else: key = key[len(strip_prefix) :] diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index c267bd54..dc344a52 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -10,6 +10,7 @@ import warnings from typing import Any, Callable, Optional, Type, TypeVar, Union +from muutils.errormode import ErrorMode from muutils.validate_type import validate_type from muutils.json_serialize.serializable_field import SerializableField, serializable_field from muutils.json_serialize.util import array_safe_eq, dc_eq @@ -18,6 +19,9 @@ T = TypeVar("T") +class CantGetTypeHintsWarning(UserWarning): + pass + class ZanjMissingWarning(UserWarning): pass @@ -64,22 +68,21 @@ def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]): return lh -OnTypeAssertDo = typing.Literal["raise", "warn", "ignore"] -DEFAULT_ON_TYPE_ASSERT: OnTypeAssertDo = "warn" +_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN +_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: SerializableField|str, - on_type_assert: OnTypeAssertDo = DEFAULT_ON_TYPE_ASSERT, + on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, ) -> bool: - # do nothing + on_typecheck_error: ErrorMode = ErrorMode.from_any(on_typecheck_error) + + # do nothing case if not field.assert_type: return True - - if on_type_assert == "ignore": - return True - + # get field if isinstance(field, str): field = self.__dataclass_fields__[field] @@ -127,19 +130,13 @@ def SerializableDataclass__validate_field_type( -def SerializableDataclass__validate_fields_types(self: SerializableDataclass, on_type_assert: OnTypeAssertDo = DEFAULT_ON_TYPE_ASSERT) -> bool: +def SerializableDataclass__validate_fields_types( + self: SerializableDataclass, + on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, + ) -> bool: """validate the types of the fields on a SerializableDataclass""" - - # arg validation - if on_type_assert not in ("raise", "warn", "ignore"): - raise ValueError( - f"Invalid value for {on_type_assert = }, expected 'raise', 'warn', or 'ignore'" - ) - - # do nothing if ignore - if on_type_assert == "ignore": - return - + on_typecheck_error: ErrorMode = ErrorMode.from_any(on_typecheck_error) + # if except, bundle the exceptions results: dict[str, bool] = dict() exceptions: dict[str, Exception] = dict() @@ -148,25 +145,20 @@ def SerializableDataclass__validate_fields_types(self: SerializableDataclass, on cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self) for field in cls_fields: try: - assert self.validate_field_type(field, on_type_assert) + results[field.name] = self.validate_field_type(field, on_typecheck_error) except Exception as e: exceptions[field.name] = e # figure out what to do with the exceptions if len(exceptions) > 0: - if on_type_assert in ("warn", "ignore"): - msg: str = ( - f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}" - + f"\n\t" + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]) - ) - if on_type_assert == "warn": - warnings.warn(msg) - else: - raise ValueError(msg) from exceptions[0] - else: - assert on_type_assert == "ignore" + on_typecheck_error.process( + f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}" + + f"\n\t" + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]), + except_cls=ValueError, + except_from=exceptions[0], + ) - return True + return all(results.values()) @@ -184,11 +176,11 @@ def serialize(self) -> dict[str, Any]: def load(cls: Type[T], data: dict[str, Any] | T) -> T: raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`") - def validate_fields_types(self, on_type_assert: OnTypeAssertDo = DEFAULT_ON_TYPE_ASSERT) -> bool: - return SerializableDataclass__validate_fields_types(self, on_type_assert) + def validate_fields_types(self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool: + return SerializableDataclass__validate_fields_types(self, on_typecheck_error=on_typecheck_error) - def validate_field_type(self, field: SerializableField|str, on_type_assert: OnTypeAssertDo = DEFAULT_ON_TYPE_ASSERT) -> bool: - return SerializableDataclass__validate_field_type(self, field, on_type_assert) + def validate_field_type(self, field: "SerializableField|str", on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool: + return SerializableDataclass__validate_field_type(self, field, on_typecheck_error=on_typecheck_error) def __eq__(self, other: Any) -> bool: return dc_eq(self, other) @@ -259,9 +251,6 @@ def __deepcopy__(self, memo: dict) -> "SerializableDataclass": return self.__class__.load(self.serialize()) -class CantGetTypeHintsWarning(UserWarning): - pass - # cache this so we don't have to keep getting it @functools.lru_cache(typed=True) @@ -477,7 +466,7 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: cls.serialize = serialize # type: ignore[attr-defined] # type is `Callable[[dict], T]` cls.load = load # type: ignore[attr-defined] - # type is `Callable[[T, OnTypeAssertDo], bool]` + # type is `Callable[[T, ErrorMode], bool]` cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined] # type is `Callable[[T, T], bool]` diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 632ae3a5..31270766 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -16,7 +16,6 @@ _NUMPY_WORKING = False ErrorMode = Literal["ignore", "warn", "except"] -TypeErrorMode = Union[ErrorMode, Literal["try_convert"]] JSONitem = Union[bool, int, float, str, list, typing.Dict[str, Any], None] diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index abc61ea2..60af8825 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -8,6 +8,7 @@ import numpy as np import torch +from muutils.errormode import ErrorMode from muutils.dictmagic import dotlist_to_nested_dict # pylint: disable=missing-class-docstring @@ -64,7 +65,7 @@ def jaxtype_factory( name: str, array_type: type, default_jax_dtype=jaxtyping.Float, - legacy_mode: typing.Literal["error", "warn", "ignore"] = "warn", + legacy_mode: ErrorMode = ErrorMode.WARN, ) -> type: """usage: ``` @@ -72,6 +73,7 @@ def jaxtype_factory( x: ATensor["dim1 dim2", np.float32] ``` """ + legacy_mode = ErrorMode.from_any(legacy_mode) class _BaseArray: """jaxtyping shorthand @@ -117,14 +119,10 @@ def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] elif isinstance(params[0], tuple): - if legacy_mode == "error": - raise Exception( - f"legacy mode is set to error, but legacy type was used:\n{cls.param_info(params)}" - ) - elif legacy_mode == "warn": - warnings.warn( - f"legacy type annotation was used:\n{cls.param_info(params)}" - ) + legacy_mode.process( + f"legacy type annotation was used:\n{cls.param_info(params) = }", + except_cls=Exception, + ) # MyTensor[("dim1", "dim2"), int] shape_anot: list[str] = list() for x in params[0]: diff --git a/muutils/validate_type.py b/muutils/validate_type.py index 11fb6fcd..f1de76b0 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -133,16 +133,16 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: if origin is type: return True - + # TODO: Callables, etc. - raise ValueError( + raise NotImplementedError( f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }", f"{GenericAliasTypes = }", ) else: - raise ValueError( + raise NotImplementedError( f"Unsupported type hint {expected_type = } for {value = }", f"{GenericAliasTypes = }", ) From 82f67b9dc91bdcbed9c0710fc1d9920db4316770 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 17:11:49 -0700 Subject: [PATCH 142/158] wipgs! many docstrings improved too --- .../json_serialize/serializable_dataclass.py | 214 +++++++++++++++--- muutils/json_serialize/serializable_field.py | 30 ++- 2 files changed, 202 insertions(+), 42 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index dc344a52..f3aaef89 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -77,6 +77,21 @@ def SerializableDataclass__validate_field_type( field: SerializableField|str, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, ) -> bool: + """given a dataclass, check the field matches the type hint + + # Parameters: + - `self : SerializableDataclass` + `SerializableDataclass` instance + - `field : SerializableField | str` + field to validate, will get from `self.__dataclass_fields__` if an `str` + - `on_typecheck_error : ErrorMode` + what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` + (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`) + + # Returns: + - `bool` + if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore` + """ on_typecheck_error: ErrorMode = ErrorMode.from_any(on_typecheck_error) # do nothing case @@ -100,41 +115,50 @@ def SerializableDataclass__validate_field_type( # validate the type if field_type_hint is not None: try: - # validate the type - type_is_valid: bool = validate_type( - value, field_type_hint - ) + type_is_valid: bool + # validate the type with the default type validator + if field.custom_typecheck_fn is None: + type_is_valid = validate_type( + value, field_type_hint + ) + # validate the type with a custom type validator + else: + type_is_valid = field.custom_typecheck_fn(field_type_hint) - # if not valid, raise or warn depending on the setting in the SerializableDataclass - if not type_is_valid: - msg: str = f"Field '{field.name}' on class {self.__class__.__name__} has type {type(value)}, but expected {field_type_hint}" - if on_type_assert == "raise": - raise ValueError(msg) - else: - warnings.warn(msg) + return type_is_valid except Exception as e: - raise ValueError( + on_typecheck_error.process( "exception while validating type: " - + f"{field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }" - ) from e + + f"{field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }", + except_cls=ValueError, + except_from=e, + ) + return False else: - raise ValueError( - f"Cannot get type hints for {self.__class__.__name__}, field {field.name = } and so cannot validate." - + f"Python version is {sys.version_info = }. You can:\n" - + f" - disable `assert_type`. Currently: {field.assert_type = }\n" - + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {field.type = }\n" - + " - use python 3.9.x or higher\n" - + " - coming in a future release, specify custom type validation functions\n" + on_typecheck_error.process( + ( + f"Cannot get type hints for {self.__class__.__name__}, field {field.name = } and so cannot validate." + + f"Python version is {sys.version_info = }. You can:\n" + + f" - disable `assert_type`. Currently: {field.assert_type = }\n" + + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {field.type = }\n" + + " - use python 3.9.x or higher\n" + + " - coming in a future release, specify custom type validation functions\n" + ), + except_cls=ValueError, ) + return False -def SerializableDataclass__validate_fields_types( +def SerializableDataclass__validate_fields_types__dict( self: SerializableDataclass, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, - ) -> bool: - """validate the types of the fields on a SerializableDataclass""" + ) -> dict[str, bool]: + """validate the types of all the fields on a SerializableDataclass. calls `SerializableDataclass__validate_field_type` for each field + + returns a dict of field names to bools, where the bool is if the field type is valid + """ on_typecheck_error: ErrorMode = ErrorMode.from_any(on_typecheck_error) # if except, bundle the exceptions @@ -147,6 +171,7 @@ def SerializableDataclass__validate_fields_types( try: results[field.name] = self.validate_field_type(field, on_typecheck_error) except Exception as e: + results[field.name] = False exceptions[field.name] = e # figure out what to do with the exceptions @@ -158,9 +183,14 @@ def SerializableDataclass__validate_fields_types( except_from=exceptions[0], ) - return all(results.values()) + return results - +def SerializableDataclass__validate_fields_types( + self: SerializableDataclass, + on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, +) -> bool: + """validate the types of all the fields on a SerializableDataclass. calls `SerializableDataclass__validate_field_type` for each field""" + return all(SerializableDataclass__validate_fields_types__dict(self, on_typecheck_error=on_typecheck_error).values()) class SerializableDataclass(abc.ABC): @@ -191,28 +221,60 @@ def __hash__(self) -> int: def diff( self, other: "SerializableDataclass", of_serialized: bool = False ) -> dict[str, Any]: + """get a rich and recursive diff between two instances of a serializable dataclass + + ```python + >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3)) + {'b': {'self': 2, 'other': 3}} + >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3))) + {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}} + ``` + + # Parameters: + - `other : SerializableDataclass` + other instance to compare against + - `of_serialized : bool` + if true, compare serialized data and not raw values + (defaults to `False`) + + # Returns: + - `dict[str, Any]` + + + # Raises: + - `ValueError` : if the instances are not of the same type + - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass` + """ + # match types if type(self) != type(other): raise ValueError( - f"Instances must be of the same type, but got {type(self)} and {type(other)}" + f"Instances must be of the same type, but got {type(self) = } and {type(other) = }" ) + # initialize the diff result diff_result: dict = {} + # if they are the same, return the empty diff if self == other: return diff_result + # if we are working with serialized data, serialize the instances if of_serialized: ser_self: dict = self.serialize() ser_other: dict = other.serialize() + # for each field in the class for field in dataclasses.fields(self): # type: ignore[arg-type] + # skip fields that are not for comparison if not field.compare: continue - + + # get values field_name: str = field.name self_value = getattr(self, field_name) other_value = getattr(other, field_name) + # if the values are both serializable dataclasses, recurse if isinstance(self_value, SerializableDataclass) and isinstance( other_value, SerializableDataclass ): @@ -221,19 +283,29 @@ def diff( ) if nested_diff: diff_result[field_name] = nested_diff + # only support serializable dataclasses elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass( other_value ): raise ValueError("Non-serializable dataclass is not supported") else: + # get the values of either the serialized or the actual values self_value_s = ser_self[field_name] if of_serialized else self_value other_value_s = ser_other[field_name] if of_serialized else other_value + # compare the values if not array_safe_eq(self_value_s, other_value_s): diff_result[field_name] = {"self": self_value, "other": other_value} + # return the diff result return diff_result def update_from_nested_dict(self, nested_dict: dict[str, Any]): + """update the instance from a nested dict, useful for configuration from command line args + + # Parameters: + - `nested_dict : dict[str, Any]` + nested dict to update the instance with + """ for field in dataclasses.fields(self): # type: ignore[arg-type] field_name: str = field.name self_value = getattr(self, field_name) @@ -245,10 +317,12 @@ def update_from_nested_dict(self, nested_dict: dict[str, Any]): setattr(self, field_name, nested_dict[field_name]) def __copy__(self) -> "SerializableDataclass": - return self.__class__.load(self.serialize()) + "deep copy by serializing and loading the instance to json" + return self.__class__.load(json.loads(json.dumps(self.serialize()))) def __deepcopy__(self, memo: dict) -> "SerializableDataclass": - return self.__class__.load(self.serialize()) + "deep copy by serializing and loading the instance to json" + return self.__class__.load(json.loads(json.dumps(self.serialize()))) @@ -294,12 +368,72 @@ def serializable_dataclass( frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, - on_type_assert: typing.Literal[ - "raise", "warn", "ignore" - ] = "warn", # TODO: change default to "raise" once more stable + on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, + on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH, **kwargs, ): + """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass` + + types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE` + + behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs + + Returns the same class as was passed in, with dunder methods added based on the fields defined in the class. + + Examines PEP 526 __annotations__ to determine fields. + + If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation. + + ```python + @serializable_dataclass(kw_only=True) + class Myclass(SerializableDataclass): + a: int + b: str + ``` + ```python + >>> Myclass(a=1, b="q").serialize() + {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'} + ``` + + # Parameters: + - `_cls : _type_` + class to decorate. don't pass this arg, just use this as a decorator + (defaults to `None`) + - `init : bool` + (defaults to `True`) + - `repr : bool` + (defaults to `True`) + - `order : bool` + (defaults to `False`) + - `unsafe_hash : bool` + (defaults to `False`) + - `frozen : bool` + (defaults to `False`) + - `properties_to_serialize : Optional[list[str]]` + **SerializableDataclass only:** which properties to add to the serialized data dict + (defaults to `None`) + - `register_handler : bool` + **SerializableDataclass only:** if true, register the class with ZANJ for loading + (defaults to `True`) + - `on_typecheck_error : ErrorMode` + **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false + - `on_typecheck_mismatch : ErrorMode` + **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True` + + # Returns: + - `_type_` + _description_ + + # Raises: + - `ValueError` : _description_ + - `ValueError` : _description_ + - `ValueError` : _description_ + - `AttributeError` : _description_ + - `ValueError` : _description_ + """ # -> Union[Callable[[Type[T]], Type[T]], Type[T]]: + on_typecheck_error = ErrorMode.from_any(on_typecheck_error) + on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch) if properties_to_serialize is None: _properties_to_serialize: list = list() @@ -327,6 +461,7 @@ def wrap(cls: Type[T]) -> Type[T]: else: del kwargs["kw_only"] + # call `dataclasses.dataclass` to set some stuff up cls = dataclasses.dataclass( # type: ignore[call-overload] cls, init=init, @@ -338,8 +473,13 @@ def wrap(cls: Type[T]) -> Type[T]: **kwargs, ) + # copy these to the class cls._properties_to_serialize = _properties_to_serialize.copy() # type: ignore[attr-defined] + # ====================================================================== + # define `serialize` func + # done locally since it depends on args to the decorator + # ====================================================================== def serialize(self) -> dict[str, Any]: result: dict[str, Any] = { "__format__": f"{self.__class__.__name__}(SerializableDataclass)" @@ -398,6 +538,10 @@ def serialize(self) -> dict[str, Any]: return result + # ====================================================================== + # define `load` func + # done locally since it depends on args to the decorator + # ====================================================================== # mypy thinks this isnt a classmethod @classmethod # type: ignore[misc] def load(cls, data: dict[str, Any] | T) -> Type[T]: @@ -455,8 +599,8 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: output: cls = cls(**ctor_kwargs) # validate the types of the fields if needed - if on_type_assert in ("raise", "warn"): - output.validate_fields_types() + if on_typecheck_mismatch != ErrorMode.IGNORE: + output.validate_fields_types(on_typecheck_error=on_typecheck_error) # return the new instance return output diff --git a/muutils/json_serialize/serializable_field.py b/muutils/json_serialize/serializable_field.py index 86d0a1f1..5fc8ac9d 100644 --- a/muutils/json_serialize/serializable_field.py +++ b/muutils/json_serialize/serializable_field.py @@ -36,6 +36,7 @@ class SerializableField(dataclasses.Field): "serialization_fn", "loading_fn", "assert_type", + "custom_typecheck_fn", ) def __init__( @@ -55,7 +56,7 @@ def __init__( serialization_fn: Optional[Callable[[Any], Any]] = None, loading_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, - # TODO: add field for custom type assertion + custom_typecheck_fn: Optional[Callable[[type], bool]] = None, ): # TODO: should we do this check, or assume the user knows what they are doing? if init and not serialize: @@ -92,6 +93,7 @@ def __init__( self.serialization_fn: Optional[Callable[[Any], Any]] = serialization_fn self.loading_fn: Optional[Callable[[Any], Any]] = loading_fn self.assert_type: bool = assert_type + self.custom_typecheck_fn: Optional[Callable[[type], bool]] = custom_typecheck_fn @classmethod def from_Field(cls, field: dataclasses.Field) -> "SerializableField": @@ -111,14 +113,9 @@ def from_Field(cls, field: dataclasses.Field) -> "SerializableField": ) -# Step 2: Create a serializable_field function # no type hint to avoid confusing mypy def serializable_field(*args, **kwargs): # -> SerializableField: - """Create a new SerializableField - - note that if not using ZANJ, and you have a class inside a container, you MUST provide - `serialization_fn` and `loading_fn` to serialize and load the container. - ZANJ will automatically do this for you. + """Create a new `SerializableField`. type hinting this func confuses mypy, so scroll down ``` default: Any | dataclasses._MISSING_TYPE = dataclasses.MISSING, @@ -130,10 +127,29 @@ def serializable_field(*args, **kwargs): # -> SerializableField: compare: bool = True, metadata: types.MappingProxyType | None = None, kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING, + # new in `SerializableField`, not in `dataclasses.Field` serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, loading_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, + custom_typecheck_fn: Optional[Callable[[type], bool]] = None, ``` + + # new Parameters: + - `serialize`: whether to serialize this field when serializing the class' + - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize` + - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is. + + # Gotchas: + - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write: + ```python + class MyClass: + my_field: int = serializable_field(loading_fn=lambda x["my_field"]: x) + ``` + issue to add a different way of doing this: https://github.com/mivanit/muutils/issues/40 + + note that if not using ZANJ, and you have a class inside a container, you MUST provide + `serialization_fn` and `loading_fn` to serialize and load the container. + ZANJ will automatically do this for you. """ return SerializableField(*args, **kwargs) \ No newline at end of file From 9456ee465c85a86effc9fbb6c49b13f0eacc0ed1 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 17:19:56 -0700 Subject: [PATCH 143/158] fixed non-init and non-serialize field, hacky --- muutils/json_serialize/serializable_dataclass.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index f3aaef89..a3c4e431 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -97,6 +97,14 @@ def SerializableDataclass__validate_field_type( # do nothing case if not field.assert_type: return True + + # if field is not `init` or not `serialize`, skip but warn + # TODO: how to handle fields which are not `init` or `serialize`? + if not field.init or not field.serialize: + warnings.warn( + f"Field '{field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked" + ) + return True # get field if isinstance(field, str): @@ -180,7 +188,8 @@ def SerializableDataclass__validate_fields_types__dict( f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}" + f"\n\t" + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]), except_cls=ValueError, - except_from=exceptions[0], + # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict + except_from=list(exceptions.values())[0], ) return results From 1721773f97d43c0fc6dd5d525580c713f556debe Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 17:20:34 -0700 Subject: [PATCH 144/158] run format --- muutils/dictmagic.py | 3 +- muutils/errormode.py | 21 +-- .../json_serialize/serializable_dataclass.py | 138 +++++++++------- muutils/json_serialize/serializable_field.py | 12 +- muutils/json_serialize/util.py | 2 +- muutils/tensor_utils.py | 1 - muutils/validate_type.py | 1 - tests/unit/test_errormode.py | 151 ++++++++++-------- 8 files changed, 179 insertions(+), 150 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index f30cf2c0..07746e5e 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -15,10 +15,11 @@ Union, ) +from muutils.errormode import ErrorMode + _KT = TypeVar("_KT") _VT = TypeVar("_VT") -from muutils.errormode import ErrorMode class DefaulterDict(typing.Dict[_KT, _VT], Generic[_KT, _VT]): """like a defaultdict, but default_factory is passed the key as an argument""" diff --git a/muutils/errormode.py b/muutils/errormode.py index 4d64b402..429a188a 100644 --- a/muutils/errormode.py +++ b/muutils/errormode.py @@ -4,18 +4,19 @@ import warnings from enum import Enum + class ErrorMode(Enum): EXCEPT = "except" WARN = "warn" IGNORE = "ignore" - + def process( - self, - msg: str, - except_cls: typing.Type[Exception] = ValueError, - warn_cls: typing.Type[Warning] = UserWarning, - except_from: typing.Optional[typing.Type[Exception]] = None, - ): + self, + msg: str, + except_cls: typing.Type[Exception] = ValueError, + warn_cls: typing.Type[Warning] = UserWarning, + except_from: typing.Optional[typing.Type[Exception]] = None, + ): if self is ErrorMode.EXCEPT: if except_from is not None: raise except_cls(msg) from except_from @@ -27,7 +28,7 @@ def process( pass else: raise ValueError(f"Unknown error mode {self}") - + @classmethod def from_any(cls, mode: "str|ErrorMode", allow_aliases: bool = True) -> ErrorMode: if isinstance(mode, ErrorMode): @@ -43,7 +44,7 @@ def from_any(cls, mode: "str|ErrorMode", allow_aliases: bool = True) -> ErrorMod return ERROR_MODE_ALIASES[mode] else: raise TypeError(f"Expected {ErrorMode} or str, got {type(mode) = }") - + ERROR_MODE_ALIASES: dict[str, ErrorMode] = { # base @@ -63,4 +64,4 @@ def from_any(cls, mode: "str|ErrorMode", allow_aliases: bool = True) -> ErrorMod "silent": ErrorMode.IGNORE, "quiet": ErrorMode.IGNORE, "nothing": ErrorMode.IGNORE, -} \ No newline at end of file +} diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index a3c4e431..537d9269 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -5,20 +5,23 @@ import functools import json import sys -import types import typing import warnings -from typing import Any, Callable, Optional, Type, TypeVar, Union +from typing import Any, Optional, Type, TypeVar from muutils.errormode import ErrorMode from muutils.validate_type import validate_type -from muutils.json_serialize.serializable_field import SerializableField, serializable_field +from muutils.json_serialize.serializable_field import ( + SerializableField, + serializable_field, +) from muutils.json_serialize.util import array_safe_eq, dc_eq # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access T = TypeVar("T") + class CantGetTypeHintsWarning(UserWarning): pass @@ -72,32 +75,33 @@ def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]): _DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN _DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT + def SerializableDataclass__validate_field_type( self: SerializableDataclass, - field: SerializableField|str, + field: SerializableField | str, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, ) -> bool: """given a dataclass, check the field matches the type hint - + # Parameters: - - `self : SerializableDataclass` + - `self : SerializableDataclass` `SerializableDataclass` instance - - `field : SerializableField | str` + - `field : SerializableField | str` field to validate, will get from `self.__dataclass_fields__` if an `str` - - `on_typecheck_error : ErrorMode` + - `on_typecheck_error : ErrorMode` what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`) - + # Returns: - - `bool` + - `bool` if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore` - """ + """ on_typecheck_error: ErrorMode = ErrorMode.from_any(on_typecheck_error) # do nothing case if not field.assert_type: return True - + # if field is not `init` or not `serialize`, skip but warn # TODO: how to handle fields which are not `init` or `serialize`? if not field.init or not field.serialize: @@ -119,16 +123,14 @@ def SerializableDataclass__validate_field_type( # get the value value: Any = getattr(self, field.name) - - # validate the type + + # validate the type if field_type_hint is not None: try: type_is_valid: bool # validate the type with the default type validator if field.custom_typecheck_fn is None: - type_is_valid = validate_type( - value, field_type_hint - ) + type_is_valid = validate_type(value, field_type_hint) # validate the type with a custom type validator else: type_is_valid = field.custom_typecheck_fn(field_type_hint) @@ -158,21 +160,20 @@ def SerializableDataclass__validate_field_type( return False - def SerializableDataclass__validate_fields_types__dict( - self: SerializableDataclass, - on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, - ) -> dict[str, bool]: + self: SerializableDataclass, + on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, +) -> dict[str, bool]: """validate the types of all the fields on a SerializableDataclass. calls `SerializableDataclass__validate_field_type` for each field - + returns a dict of field names to bools, where the bool is if the field type is valid """ on_typecheck_error: ErrorMode = ErrorMode.from_any(on_typecheck_error) - + # if except, bundle the exceptions results: dict[str, bool] = dict() exceptions: dict[str, Exception] = dict() - + # for each field in the class cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self) for field in cls_fields: @@ -186,20 +187,26 @@ def SerializableDataclass__validate_fields_types__dict( if len(exceptions) > 0: on_typecheck_error.process( f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}" - + f"\n\t" + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]), + + "\n\t" + + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]), except_cls=ValueError, # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict except_from=list(exceptions.values())[0], ) - + return results + def SerializableDataclass__validate_fields_types( self: SerializableDataclass, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, ) -> bool: """validate the types of all the fields on a SerializableDataclass. calls `SerializableDataclass__validate_field_type` for each field""" - return all(SerializableDataclass__validate_fields_types__dict(self, on_typecheck_error=on_typecheck_error).values()) + return all( + SerializableDataclass__validate_fields_types__dict( + self, on_typecheck_error=on_typecheck_error + ).values() + ) class SerializableDataclass(abc.ABC): @@ -209,17 +216,29 @@ class SerializableDataclass(abc.ABC): """ def serialize(self) -> dict[str, Any]: - raise NotImplementedError(f"decorate {self.__class__ = } with `@serializable_dataclass`") + raise NotImplementedError( + f"decorate {self.__class__ = } with `@serializable_dataclass`" + ) @classmethod def load(cls: Type[T], data: dict[str, Any] | T) -> T: raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`") - - def validate_fields_types(self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool: - return SerializableDataclass__validate_fields_types(self, on_typecheck_error=on_typecheck_error) - - def validate_field_type(self, field: "SerializableField|str", on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool: - return SerializableDataclass__validate_field_type(self, field, on_typecheck_error=on_typecheck_error) + + def validate_fields_types( + self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR + ) -> bool: + return SerializableDataclass__validate_fields_types( + self, on_typecheck_error=on_typecheck_error + ) + + def validate_field_type( + self, + field: "SerializableField|str", + on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, + ) -> bool: + return SerializableDataclass__validate_field_type( + self, field, on_typecheck_error=on_typecheck_error + ) def __eq__(self, other: Any) -> bool: return dc_eq(self, other) @@ -231,7 +250,7 @@ def diff( self, other: "SerializableDataclass", of_serialized: bool = False ) -> dict[str, Any]: """get a rich and recursive diff between two instances of a serializable dataclass - + ```python >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3)) {'b': {'self': 2, 'other': 3}} @@ -240,20 +259,20 @@ def diff( ``` # Parameters: - - `other : SerializableDataclass` + - `other : SerializableDataclass` other instance to compare against - - `of_serialized : bool` + - `of_serialized : bool` if true, compare serialized data and not raw values (defaults to `False`) - + # Returns: - - `dict[str, Any]` - - + - `dict[str, Any]` + + # Raises: - `ValueError` : if the instances are not of the same type - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass` - """ + """ # match types if type(self) != type(other): raise ValueError( @@ -277,7 +296,7 @@ def diff( # skip fields that are not for comparison if not field.compare: continue - + # get values field_name: str = field.name self_value = getattr(self, field_name) @@ -310,7 +329,7 @@ def diff( def update_from_nested_dict(self, nested_dict: dict[str, Any]): """update the instance from a nested dict, useful for configuration from command line args - + # Parameters: - `nested_dict : dict[str, Any]` nested dict to update the instance with @@ -334,7 +353,6 @@ def __deepcopy__(self, memo: dict) -> "SerializableDataclass": return self.__class__.load(json.loads(json.dumps(self.serialize()))) - # cache this so we don't have to keep getting it @functools.lru_cache(typed=True) def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]: @@ -360,7 +378,7 @@ def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]: + f" {dataclasses.fields(cls) = }\n" + f" {e = }" ) from e - + return cls_type_hints @@ -392,7 +410,7 @@ def serializable_dataclass( Examines PEP 526 __annotations__ to determine fields. If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation. - + ```python @serializable_dataclass(kw_only=True) class Myclass(SerializableDataclass): @@ -403,43 +421,43 @@ class Myclass(SerializableDataclass): >>> Myclass(a=1, b="q").serialize() {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'} ``` - + # Parameters: - - `_cls : _type_` + - `_cls : _type_` class to decorate. don't pass this arg, just use this as a decorator (defaults to `None`) - - `init : bool` + - `init : bool` (defaults to `True`) - - `repr : bool` + - `repr : bool` (defaults to `True`) - - `order : bool` + - `order : bool` (defaults to `False`) - - `unsafe_hash : bool` + - `unsafe_hash : bool` (defaults to `False`) - - `frozen : bool` + - `frozen : bool` (defaults to `False`) - - `properties_to_serialize : Optional[list[str]]` + - `properties_to_serialize : Optional[list[str]]` **SerializableDataclass only:** which properties to add to the serialized data dict (defaults to `None`) - - `register_handler : bool` + - `register_handler : bool` **SerializableDataclass only:** if true, register the class with ZANJ for loading (defaults to `True`) - `on_typecheck_error : ErrorMode` **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false - `on_typecheck_mismatch : ErrorMode` **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True` - + # Returns: - - `_type_` + - `_type_` _description_ - + # Raises: - `ValueError` : _description_ - `ValueError` : _description_ - `ValueError` : _description_ - `AttributeError` : _description_ - `ValueError` : _description_ - """ + """ # -> Union[Callable[[Type[T]], Type[T]], Type[T]]: on_typecheck_error = ErrorMode.from_any(on_typecheck_error) on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch) diff --git a/muutils/json_serialize/serializable_field.py b/muutils/json_serialize/serializable_field.py index 5fc8ac9d..54896d6f 100644 --- a/muutils/json_serialize/serializable_field.py +++ b/muutils/json_serialize/serializable_field.py @@ -1,16 +1,10 @@ from __future__ import annotations -import abc import dataclasses -import functools -import json import sys import types -import typing -import warnings -from typing import Any, Callable, Optional, Type, TypeVar, Union +from typing import Any, Callable, Optional, Union -from muutils.validate_type import validate_type # pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access @@ -148,8 +142,8 @@ class MyClass: ``` issue to add a different way of doing this: https://github.com/mivanit/muutils/issues/40 - note that if not using ZANJ, and you have a class inside a container, you MUST provide + note that if not using ZANJ, and you have a class inside a container, you MUST provide `serialization_fn` and `loading_fn` to serialize and load the container. ZANJ will automatically do this for you. """ - return SerializableField(*args, **kwargs) \ No newline at end of file + return SerializableField(*args, **kwargs) diff --git a/muutils/json_serialize/util.py b/muutils/json_serialize/util.py index 31270766..306fd6d6 100644 --- a/muutils/json_serialize/util.py +++ b/muutils/json_serialize/util.py @@ -6,7 +6,7 @@ import sys import typing import warnings -from typing import Any, Callable, Iterable, Literal, TypeVar, Union +from typing import Any, Callable, Iterable, Literal, Union _NUMPY_WORKING: bool try: diff --git a/muutils/tensor_utils.py b/muutils/tensor_utils.py index 60af8825..c15b7730 100644 --- a/muutils/tensor_utils.py +++ b/muutils/tensor_utils.py @@ -2,7 +2,6 @@ import json import typing -import warnings import jaxtyping import numpy as np diff --git a/muutils/validate_type.py b/muutils/validate_type.py index f1de76b0..24f5bad4 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -132,7 +132,6 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # check is type if origin is type: return True - # TODO: Callables, etc. diff --git a/tests/unit/test_errormode.py b/tests/unit/test_errormode.py index 0f9d965e..a6d4e272 100644 --- a/tests/unit/test_errormode.py +++ b/tests/unit/test_errormode.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing import warnings from muutils.errormode import ErrorMode, ERROR_MODE_ALIASES @@ -38,8 +37,6 @@ def process( """ - - def test_except(): with pytest.raises(ValueError): ErrorMode.EXCEPT.process("test-except", except_cls=ValueError) @@ -49,12 +46,14 @@ def test_except(): with pytest.raises(RuntimeError): ErrorMode.EXCEPT.process("test-except", except_cls=RuntimeError) - + with pytest.raises(KeyError): ErrorMode.EXCEPT.process("test-except", except_cls=KeyError) with pytest.raises(KeyError): - ErrorMode.EXCEPT.process("test-except", except_cls=KeyError, except_from=ValueError("base exception")) + ErrorMode.EXCEPT.process( + "test-except", except_cls=KeyError, except_from=ValueError("base exception") + ) def test_warn(): @@ -67,6 +66,7 @@ def test_warn(): with pytest.warns(DeprecationWarning): ErrorMode.WARN.process("test-warn", warn_cls=DeprecationWarning) + def test_ignore(): with warnings.catch_warnings(record=True) as w: ErrorMode.IGNORE.process("test-ignore") @@ -78,6 +78,7 @@ def test_ignore(): assert len(w) == 0, f"There should be no warnings: {w}" + def test_except_custom(): class MyCustomError(ValueError): pass @@ -85,6 +86,7 @@ class MyCustomError(ValueError): with pytest.raises(MyCustomError): ErrorMode.EXCEPT.process("test-except", except_cls=MyCustomError) + def test_warn_custom(): class MyCustomWarning(Warning): pass @@ -105,7 +107,11 @@ def test_except_mode_chained_exception(): raise KeyError("base exception") except Exception as base_exception: # catch it, raise another exception with it as the cause - ErrorMode.EXCEPT.process("Test chained exception", except_cls=RuntimeError, except_from=base_exception) + ErrorMode.EXCEPT.process( + "Test chained exception", + except_cls=RuntimeError, + except_from=base_exception, + ) # catch the outer exception except RuntimeError as e: assert str(e) == "Test chained exception" @@ -116,79 +122,90 @@ def test_except_mode_chained_exception(): assert False, "Expected RuntimeError with cause KeyError" - - -@pytest.mark.parametrize("mode, expected_mode", [ - ("except", ErrorMode.EXCEPT), - ("warn", ErrorMode.WARN), - ("ignore", ErrorMode.IGNORE), - ("Except", ErrorMode.EXCEPT), - ("Warn", ErrorMode.WARN), - ("Ignore", ErrorMode.IGNORE), - (" \teXcEpT \n", ErrorMode.EXCEPT), - ("WaRn \t", ErrorMode.WARN), - (" \tIGNORE", ErrorMode.IGNORE), -]) +@pytest.mark.parametrize( + "mode, expected_mode", + [ + ("except", ErrorMode.EXCEPT), + ("warn", ErrorMode.WARN), + ("ignore", ErrorMode.IGNORE), + ("Except", ErrorMode.EXCEPT), + ("Warn", ErrorMode.WARN), + ("Ignore", ErrorMode.IGNORE), + (" \teXcEpT \n", ErrorMode.EXCEPT), + ("WaRn \t", ErrorMode.WARN), + (" \tIGNORE", ErrorMode.IGNORE), + ], +) def test_from_any_strict_ok(mode, expected_mode): assert ErrorMode.from_any(mode, allow_aliases=False) == expected_mode -@pytest.mark.parametrize("mode, excepted_error", [ - (42, TypeError), - (42.0, TypeError), - (None, TypeError), - (object(), TypeError), - (True, TypeError), - (False, TypeError), - (["except"], TypeError), - ("invalid", KeyError), - (" \tinvalid", KeyError), - ("e", KeyError), - (" E", KeyError), - ("w", KeyError), - ("W", KeyError), - ("i", KeyError), - ("I", KeyError), - ("silent", KeyError), - ("Silent", KeyError), - ("quiet", KeyError), - ("Quiet", KeyError), - ("raise", KeyError), - ("Raise", KeyError), - ("error", KeyError), - ("Error", KeyError), - ("err", KeyError), - ("ErR\t", KeyError), - ("warning", KeyError), - ("Warning", KeyError), -]) + +@pytest.mark.parametrize( + "mode, excepted_error", + [ + (42, TypeError), + (42.0, TypeError), + (None, TypeError), + (object(), TypeError), + (True, TypeError), + (False, TypeError), + (["except"], TypeError), + ("invalid", KeyError), + (" \tinvalid", KeyError), + ("e", KeyError), + (" E", KeyError), + ("w", KeyError), + ("W", KeyError), + ("i", KeyError), + ("I", KeyError), + ("silent", KeyError), + ("Silent", KeyError), + ("quiet", KeyError), + ("Quiet", KeyError), + ("raise", KeyError), + ("Raise", KeyError), + ("error", KeyError), + ("Error", KeyError), + ("err", KeyError), + ("ErR\t", KeyError), + ("warning", KeyError), + ("Warning", KeyError), + ], +) def test_from_any_strict_error(mode, excepted_error): with pytest.raises(excepted_error): ErrorMode.from_any(mode, allow_aliases=False) -@pytest.mark.parametrize("mode, expected_mode", [ - *list(ERROR_MODE_ALIASES.items()), - *list((a.upper(), b) for a, b in ERROR_MODE_ALIASES.items()), - *list((a.title(), b) for a, b in ERROR_MODE_ALIASES.items()), - *list((a.capitalize(), b) for a, b in ERROR_MODE_ALIASES.items()), - *list((f" \t{a} \t", b) for a, b in ERROR_MODE_ALIASES.items()), -]) +@pytest.mark.parametrize( + "mode, expected_mode", + [ + *list(ERROR_MODE_ALIASES.items()), + *list((a.upper(), b) for a, b in ERROR_MODE_ALIASES.items()), + *list((a.title(), b) for a, b in ERROR_MODE_ALIASES.items()), + *list((a.capitalize(), b) for a, b in ERROR_MODE_ALIASES.items()), + *list((f" \t{a} \t", b) for a, b in ERROR_MODE_ALIASES.items()), + ], +) def test_from_any_aliases_ok(mode, expected_mode): assert ErrorMode.from_any(mode) == expected_mode assert ErrorMode.from_any(mode, allow_aliases=True) == expected_mode -@pytest.mark.parametrize("mode, excepted_error", [ - (42, TypeError), - (42.0, TypeError), - (None, TypeError), - (object(), TypeError), - (True, TypeError), - (False, TypeError), - (["except"], TypeError), - ("invalid", KeyError), - (" \tinvalid", KeyError), -]) +@pytest.mark.parametrize( + "mode, excepted_error", + [ + (42, TypeError), + (42.0, TypeError), + (None, TypeError), + (object(), TypeError), + (True, TypeError), + (False, TypeError), + (["except"], TypeError), + ("invalid", KeyError), + (" \tinvalid", KeyError), + ], +) def test_from_any_aliases_error(mode, excepted_error): with pytest.raises(excepted_error): - ErrorMode.from_any(mode, allow_aliases=True) \ No newline at end of file + ErrorMode.from_any(mode, allow_aliases=True) From 67f4cbbf213d2920ac6d1a2cc9de28a38f161705 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 17:29:03 -0700 Subject: [PATCH 145/158] fix exception for unsupported --- tests/unit/validate_type/test_validate_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index 769f85ca..b9402f58 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -248,7 +248,7 @@ def test_validate_type_tuple(value, expected_type, expected_result): ], ) def test_validate_type_unsupported_type_hint(value, expected_type): - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): validate_type(value, expected_type) print(f"Failed to except: {value = }, {expected_type = }") From a35b3f6fe5262baad2c4f2229fc8065f4a60e15e Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 17:38:51 -0700 Subject: [PATCH 146/158] special warning for field is not init or serialize --- muutils/json_serialize/serializable_dataclass.py | 7 ++++++- .../serializable_dataclass/test_serializable_dataclass.py | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 537d9269..7580fc8d 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -76,6 +76,10 @@ def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]): _DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT +class FieldIsNotInitOrSerializeWarning(UserWarning): + pass + + def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: SerializableField | str, @@ -106,7 +110,8 @@ def SerializableDataclass__validate_field_type( # TODO: how to handle fields which are not `init` or `serialize`? if not field.init or not field.serialize: warnings.warn( - f"Field '{field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked" + f"Field '{field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked", + FieldIsNotInitOrSerializeWarning, ) return True diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index cc67cb23..18f7254e 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -11,6 +11,10 @@ serializable_field, ) +from muutils.json_serialize.serializable_dataclass import ( + FieldIsNotInitOrSerializeWarning, +) + # pylint: disable=missing-class-docstring, unused-variable @@ -127,8 +131,10 @@ def test_field_options_serialization(field_options_instance): def test_field_options_loading(field_options_instance): + # ignore a `FieldIsNotInitOrSerializeWarning` serialized = field_options_instance.serialize() - loaded = FieldOptions.load(serialized) + with pytest.warns(FieldIsNotInitOrSerializeWarning): + loaded = FieldOptions.load(serialized) assert loaded == field_options_instance From 89030cf89c35728bfcef39b2f6611c9ea1db3220 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 17:39:10 -0700 Subject: [PATCH 147/158] class type validation, tests for that and aliases --- muutils/validate_type.py | 4 +- .../unit/validate_type/test_validate_type.py | 127 +++++++++++++++++- 2 files changed, 123 insertions(+), 8 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index 24f5bad4..4ab85f2a 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -130,8 +130,10 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: f"{GenericAliasTypes = }", ) # check is type - if origin is type: + if origin in value.__mro__: return True + else: + return False # TODO: Callables, etc. diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index b9402f58..8a2c06f4 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -395,13 +395,6 @@ def test_validate_type_complex(): ) -def test_validate_type_class(): - class Test: - def __init__(self, a: int, b: str): - self.a: int = a - self.b: str = b - - @pytest.mark.parametrize( "value, expected_type, expected_result", [ @@ -438,3 +431,123 @@ def test_validate_type_nested(value, expected_type, expected_result): raise Exception( f"{value = }, {expected_type = }, {expected_result = }, {e}" ) from e + + +def test_validate_type_inheritance(): + class Parent: + def __init__(self, a: int, b: str): + self.a: int = a + self.b: str = b + + class Child: + def __init__(self, a: int, b: str): + self.a: int = 2 * a + self.b: str = b + + assert validate_type(Parent(1, "a"), Parent) + assert validate_type(Child(1, "a"), Parent) + assert validate_type(Child(1, "a"), Child) + assert not validate_type(Parent(1, "a"), Child) + + +def test_validate_type_class(): + class Parent: + def __init__(self, a: int, b: str): + self.a: int = a + self.b: str = b + + class Child: + def __init__(self, a: int, b: str): + self.a: int = 2 * a + self.b: str = b + + assert validate_type(Parent, type) + assert validate_type(Child, type) + assert validate_type(Parent, typing.Type[Parent]) + assert validate_type(Child, typing.Type[Child]) + assert not validate_type(Parent, typing.Type[Child]) + + assert validate_type(Child, typing.Union[typing.Type[Child], typing.Type[Parent]]) + assert validate_type(Child, typing.Union[typing.Type[Child], int]) + + +@pytest.mark.skip(reason="Not implemented") +def test_validate_type_class_union(): + class Parent: + def __init__(self, a: int, b: str): + self.a: int = a + self.b: str = b + + class Child: + def __init__(self, a: int, b: str): + self.a: int = 2 * a + self.b: str = b + + class Other: + def __init__(self, x: int, y: str): + self.x: int = x + self.y: str = y + + assert validate_type(Child, typing.Type[typing.Union[Child, Parent]]) + assert validate_type(Child, typing.Type[typing.Union[Child, Other]]) + assert validate_type(Parent, typing.Type[typing.Union[Child, Other]]) + assert validate_type(Parent, typing.Type[typing.Union[Parent, Other]]) + + +def test_validate_type_aliases(): + AliasInt = int + AliasStr = str + AliasListInt = typing.List[int] + AliasListStr = typing.List[str] + AliasDictIntStr = typing.Dict[int, str] + AliasDictStrInt = typing.Dict[str, int] + AliasTupleIntStr = typing.Tuple[int, str] + AliasTupleStrInt = typing.Tuple[str, int] + AliasSetInt = typing.Set[int] + AliasSetStr = typing.Set[str] + AliasUnionIntStr = typing.Union[int, str] + AliasUnionStrInt = typing.Union[str, int] + AliasOptionalInt = typing.Optional[int] + AliasOptionalStr = typing.Optional[str] + AliasOptionalListInt = typing.Optional[typing.List[int]] + AliasDictStrListInt = typing.Dict[str, typing.List[int]] + + assert validate_type(42, AliasInt) + assert not validate_type("42", AliasInt) + assert validate_type(42, AliasInt) + assert not validate_type("42", AliasInt) + assert validate_type("hello", AliasStr) + assert not validate_type(42, AliasStr) + assert validate_type([1, 2, 3], AliasListInt) + assert not validate_type([1, "2", 3], AliasListInt) + assert validate_type(["hello", "world"], AliasListStr) + assert not validate_type(["hello", 42], AliasListStr) + assert validate_type({1: "a", 2: "b"}, AliasDictIntStr) + assert not validate_type({1: 2, 3: 4}, AliasDictIntStr) + assert validate_type({"one": 1, "two": 2}, AliasDictStrInt) + assert not validate_type({1: "one", 2: "two"}, AliasDictStrInt) + assert validate_type((1, "a"), AliasTupleIntStr) + assert not validate_type(("a", 1), AliasTupleIntStr) + assert validate_type(("a", 1), AliasTupleStrInt) + assert not validate_type((1, "a"), AliasTupleStrInt) + assert validate_type({1, 2, 3}, AliasSetInt) + assert not validate_type({1, "two", 3}, AliasSetInt) + assert validate_type({"one", "two"}, AliasSetStr) + assert not validate_type({"one", 2}, AliasSetStr) + assert validate_type(42, AliasUnionIntStr) + assert validate_type("hello", AliasUnionIntStr) + assert not validate_type(3.14, AliasUnionIntStr) + assert validate_type("hello", AliasUnionStrInt) + assert validate_type(42, AliasUnionStrInt) + assert not validate_type(3.14, AliasUnionStrInt) + assert validate_type(42, AliasOptionalInt) + assert validate_type(None, AliasOptionalInt) + assert not validate_type("42", AliasOptionalInt) + assert validate_type("hello", AliasOptionalStr) + assert validate_type(None, AliasOptionalStr) + assert not validate_type(42, AliasOptionalStr) + assert validate_type([1, 2, 3], AliasOptionalListInt) + assert validate_type(None, AliasOptionalListInt) + assert not validate_type(["1", "2", "3"], AliasOptionalListInt) + assert validate_type({"key": [1, 2, 3]}, AliasDictStrListInt) + assert not validate_type({"key": [1, "2", 3]}, AliasDictStrListInt) From 709fc3cc326ebc7599ecac9b69487c9871122e11 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 20 Jun 2024 17:45:08 -0700 Subject: [PATCH 148/158] typing fixes --- muutils/dictmagic.py | 2 +- muutils/errormode.py | 2 +- .../json_serialize/serializable_dataclass.py | 49 ++++++++++--------- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/muutils/dictmagic.py b/muutils/dictmagic.py index 07746e5e..37c48dc0 100644 --- a/muutils/dictmagic.py +++ b/muutils/dictmagic.py @@ -179,7 +179,7 @@ def main(**kwargs): - `transform_key: Callable[[str], str] | None = None` a function to apply to each key before adding it to the dict (applied after stripping the prefix) """ - when_unknown_prefix: ErrorMode = ErrorMode.from_any(when_unknown_prefix) + when_unknown_prefix = ErrorMode.from_any(when_unknown_prefix) filtered_kwargs: dict[str, Any] = dict() for key, value in kwargs_dict.items(): if strip_prefix is not None: diff --git a/muutils/errormode.py b/muutils/errormode.py index 429a188a..4499512d 100644 --- a/muutils/errormode.py +++ b/muutils/errormode.py @@ -15,7 +15,7 @@ def process( msg: str, except_cls: typing.Type[Exception] = ValueError, warn_cls: typing.Type[Warning] = UserWarning, - except_from: typing.Optional[typing.Type[Exception]] = None, + except_from: typing.Optional[Exception] = None, ): if self is ErrorMode.EXCEPT: if except_from is not None: diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 7580fc8d..95178f42 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -100,52 +100,55 @@ def SerializableDataclass__validate_field_type( - `bool` if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore` """ - on_typecheck_error: ErrorMode = ErrorMode.from_any(on_typecheck_error) + on_typecheck_error = ErrorMode.from_any(on_typecheck_error) + + # get field + _field: SerializableField + if isinstance(field, str): + _field = self.__dataclass_fields__[field] # type: ignore[attr-defined] + else: + _field = field # do nothing case - if not field.assert_type: + if not _field.assert_type: return True # if field is not `init` or not `serialize`, skip but warn # TODO: how to handle fields which are not `init` or `serialize`? - if not field.init or not field.serialize: + if not _field.init or not _field.serialize: warnings.warn( - f"Field '{field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked", + f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked", FieldIsNotInitOrSerializeWarning, ) return True - # get field - if isinstance(field, str): - field = self.__dataclass_fields__[field] - assert isinstance( - field, SerializableField - ), f"Field '{field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(field) = }" + _field, SerializableField + ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }" # get field type hints - field_type_hint: Any = get_cls_type_hints(self.__class__).get(field.name, None) + field_type_hint: Any = get_cls_type_hints(self.__class__).get(_field.name, None) # get the value - value: Any = getattr(self, field.name) + value: Any = getattr(self, _field.name) # validate the type if field_type_hint is not None: try: type_is_valid: bool # validate the type with the default type validator - if field.custom_typecheck_fn is None: + if _field.custom_typecheck_fn is None: type_is_valid = validate_type(value, field_type_hint) # validate the type with a custom type validator else: - type_is_valid = field.custom_typecheck_fn(field_type_hint) + type_is_valid = _field.custom_typecheck_fn(field_type_hint) return type_is_valid except Exception as e: on_typecheck_error.process( "exception while validating type: " - + f"{field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }", + + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }", except_cls=ValueError, except_from=e, ) @@ -153,10 +156,10 @@ def SerializableDataclass__validate_field_type( else: on_typecheck_error.process( ( - f"Cannot get type hints for {self.__class__.__name__}, field {field.name = } and so cannot validate." + f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate." + f"Python version is {sys.version_info = }. You can:\n" - + f" - disable `assert_type`. Currently: {field.assert_type = }\n" - + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {field.type = }\n" + + f" - disable `assert_type`. Currently: {_field.assert_type = }\n" + + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n" + " - use python 3.9.x or higher\n" + " - coming in a future release, specify custom type validation functions\n" ), @@ -173,14 +176,14 @@ def SerializableDataclass__validate_fields_types__dict( returns a dict of field names to bools, where the bool is if the field type is valid """ - on_typecheck_error: ErrorMode = ErrorMode.from_any(on_typecheck_error) + on_typecheck_error = ErrorMode.from_any(on_typecheck_error) # if except, bundle the exceptions results: dict[str, bool] = dict() exceptions: dict[str, Exception] = dict() # for each field in the class - cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self) + cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self) # type: ignore[arg-type, assignment] for field in cls_fields: try: results[field.name] = self.validate_field_type(field, on_typecheck_error) @@ -373,14 +376,14 @@ def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]: + " - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x)\n" + " - use python 3.9.x or higher\n" + " - add explicit loading functions to the fields\n" - + f" {dataclasses.fields(cls) = }", + + f" {dataclasses.fields(cls) = }", # type: ignore[arg-type] CantGetTypeHintsWarning, ) cls_type_hints = dict() else: raise TypeError( f"Cannot get type hints for {cls.__name__}. Python version is {sys.version_info = }\n" - + f" {dataclasses.fields(cls) = }\n" + + f" {dataclasses.fields(cls) = }\n" # type: ignore[arg-type] + f" {e = }" ) from e @@ -517,7 +520,7 @@ def serialize(self) -> dict[str, Any]: "__format__": f"{self.__class__.__name__}(SerializableDataclass)" } # for each field in the class - for field in dataclasses.fields(self): + for field in dataclasses.fields(self): # type: ignore[arg-type] # need it to be our special SerializableField if not isinstance(field, SerializableField): raise ValueError( From 0ec0a6d0716b5eb53eaa8f27c861636183c2d715 Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 14:54:24 -0700 Subject: [PATCH 149/158] better generated modern type hint tests --- .gitignore | 2 +- makefile | 4 ++-- tests/util/replace_type_hints.py | 13 +++++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 6508e22b..71f0c4d3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ .github/local/** # this one is cursed -tests/unit/validate_type/test_validate_type_MODERN.py +tests/unit/validate_type/test_validate_type_GENERATED.py # test notebook _test.ipynb # junk data diff --git a/makefile b/makefile index 233fa9e2..00abc156 100644 --- a/makefile +++ b/makefile @@ -174,7 +174,7 @@ test: clean if [ $(COMPATIBILITY_MODE) -eq 0 ]; then \ echo "converting certain tests to modern format"; \ - $(PYTHON) tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py > tests/unit/validate_type/test_validate_type_MODERN.py; \ + $(PYTHON) tests/util/replace_type_hints.py tests/unit/validate_type/test_validate_type.py "# DO NOT EDIT, GENERATED FILE" > tests/unit/validate_type/test_validate_type_GENERATED.py; \ fi; \ $(PYTHON) -m pytest $(PYTEST_OPTIONS) $(TESTS_DIR) @@ -262,7 +262,7 @@ clean: rm -rf tests/junk_data $(PYTHON_BASE) -Bc "import pathlib; [p.unlink() for p in pathlib.Path('.').rglob('*.py[co]')]" $(PYTHON_BASE) -Bc "import pathlib; [p.rmdir() for p in pathlib.Path('.').rglob('__pycache__')]" - rm -rf tests/unit/validate_type/test_validate_type_MODERN.py + rm -rf tests/unit/validate_type/test_validate_type_GENERATED.py # listing targets, from stackoverflow # https://stackoverflow.com/questions/4219255/how-do-you-get-the-list-of-targets-in-a-makefile diff --git a/tests/util/replace_type_hints.py b/tests/util/replace_type_hints.py index c024dd18..64abb247 100644 --- a/tests/util/replace_type_hints.py +++ b/tests/util/replace_type_hints.py @@ -1,4 +1,4 @@ -def replace_typing_aliases(filename): +def replace_typing_aliases(filename: str) -> str: # Dictionary to map old types from the typing module to the new built-in types replacements = { "typing.List": "list", @@ -15,11 +15,16 @@ def replace_typing_aliases(filename): for old, new in replacements.items(): content = content.replace(old, new) - # Print the modified content to stdout - print(content) + # return the modified content + return content if __name__ == "__main__": import sys - replace_typing_aliases(sys.argv[1]) + file: str = sys.argv[1] + prefix: str = "" + if len(sys.argv) > 1: + prefix = "\n".join(sys.argv[2:]) + + print(prefix + "\n" + replace_typing_aliases(file)) From abc9166e2a1727474da96a8c005fc8c46576fff4 Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 14:54:33 -0700 Subject: [PATCH 150/158] update deps --- .github/dev-requirements.txt | 4 +-- .github/lint-requirements.txt | 2 +- poetry.lock | 56 +++++++++++++++++------------------ 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/.github/dev-requirements.txt b/.github/dev-requirements.txt index 7cda1761..71251943 100644 --- a/.github/dev-requirements.txt +++ b/.github/dev-requirements.txt @@ -10,7 +10,7 @@ decorator==5.1.1 ; python_version >= "3.10" and python_version < "4.0" exceptiongroup==1.2.1 ; python_version >= "3.8" and python_version < "3.11" executing==2.0.1 ; python_version >= "3.10" and python_version < "4.0" fonttools==4.53.0 ; python_version >= "3.8" and python_version < "4.0" -importlib-metadata==7.1.0 ; python_version >= "3.8" and python_version < "3.10" +importlib-metadata==7.2.0 ; python_version >= "3.8" and python_version < "3.10" importlib-resources==6.4.0 ; python_version >= "3.8" and python_version < "3.10" iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "4.0" ipython==8.25.0 ; python_version >= "3.10" and python_version < "4.0" @@ -44,7 +44,7 @@ pytest==8.2.2 ; python_version >= "3.8" and python_version < "4.0" python-dateutil==2.9.0.post0 ; python_version >= "3.8" and python_version < "4.0" pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4" rich==13.7.1 ; python_version >= "3.8" and python_version < "4" -ruff==0.4.9 ; python_version >= "3.8" and python_version < "4.0" +ruff==0.4.10 ; python_version >= "3.8" and python_version < "4.0" shellingham==1.5.4 ; python_version >= "3.8" and python_version < "4" six==1.16.0 ; python_version >= "3.8" and python_version < "4.0" stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0" diff --git a/.github/lint-requirements.txt b/.github/lint-requirements.txt index 5c77b646..eac74f60 100644 --- a/.github/lint-requirements.txt +++ b/.github/lint-requirements.txt @@ -9,7 +9,7 @@ pycln==2.4.0 ; python_version >= "3.8" and python_version < "4" pygments==2.18.0 ; python_version >= "3.8" and python_version < "4" pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "4" rich==13.7.1 ; python_version >= "3.8" and python_version < "4" -ruff==0.4.9 ; python_version >= "3.8" and python_version < "4.0" +ruff==0.4.10 ; python_version >= "3.8" and python_version < "4.0" shellingham==1.5.4 ; python_version >= "3.8" and python_version < "4" tomlkit==0.12.5 ; python_version >= "3.8" and python_version < "4" typer==0.12.3 ; python_version >= "3.8" and python_version < "4" diff --git a/poetry.lock b/poetry.lock index cda5808f..dd7df846 100644 --- a/poetry.lock +++ b/poetry.lock @@ -374,22 +374,22 @@ tqdm = ["tqdm"] [[package]] name = "importlib-metadata" -version = "7.1.0" +version = "7.2.0" description = "Read metadata from Python packages" optional = true python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, - {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, + {file = "importlib_metadata-7.2.0-py3-none-any.whl", hash = "sha256:04e4aad329b8b948a5711d394fa8759cb80f009225441b4f2a02bd4d8e5f426c"}, + {file = "importlib_metadata-7.2.0.tar.gz", hash = "sha256:3ff4519071ed42740522d494d04819b666541b9752c43012f85afb2cc220fcc6"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "importlib-resources" @@ -1592,28 +1592,28 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "ruff" -version = "0.4.9" +version = "0.4.10" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.4.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b262ed08d036ebe162123170b35703aaf9daffecb698cd367a8d585157732991"}, - {file = "ruff-0.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:98ec2775fd2d856dc405635e5ee4ff177920f2141b8e2d9eb5bd6efd50e80317"}, - {file = "ruff-0.4.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4555056049d46d8a381f746680db1c46e67ac3b00d714606304077682832998e"}, - {file = "ruff-0.4.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e91175fbe48f8a2174c9aad70438fe9cb0a5732c4159b2a10a3565fea2d94cde"}, - {file = "ruff-0.4.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e8e7b95673f22e0efd3571fb5b0cf71a5eaaa3cc8a776584f3b2cc878e46bff"}, - {file = "ruff-0.4.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:2d45ddc6d82e1190ea737341326ecbc9a61447ba331b0a8962869fcada758505"}, - {file = "ruff-0.4.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:78de3fdb95c4af084087628132336772b1c5044f6e710739d440fc0bccf4d321"}, - {file = "ruff-0.4.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:06b60f91bfa5514bb689b500a25ba48e897d18fea14dce14b48a0c40d1635893"}, - {file = "ruff-0.4.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88bffe9c6a454bf8529f9ab9091c99490578a593cc9f9822b7fc065ee0712a06"}, - {file = "ruff-0.4.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:673bddb893f21ab47a8334c8e0ea7fd6598ecc8e698da75bcd12a7b9d0a3206e"}, - {file = "ruff-0.4.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8c1aff58c31948cc66d0b22951aa19edb5af0a3af40c936340cd32a8b1ab7438"}, - {file = "ruff-0.4.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:784d3ec9bd6493c3b720a0b76f741e6c2d7d44f6b2be87f5eef1ae8cc1d54c84"}, - {file = "ruff-0.4.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:732dd550bfa5d85af8c3c6cbc47ba5b67c6aed8a89e2f011b908fc88f87649db"}, - {file = "ruff-0.4.9-py3-none-win32.whl", hash = "sha256:8064590fd1a50dcf4909c268b0e7c2498253273309ad3d97e4a752bb9df4f521"}, - {file = "ruff-0.4.9-py3-none-win_amd64.whl", hash = "sha256:e0a22c4157e53d006530c902107c7f550b9233e9706313ab57b892d7197d8e52"}, - {file = "ruff-0.4.9-py3-none-win_arm64.whl", hash = "sha256:5d5460f789ccf4efd43f265a58538a2c24dbce15dbf560676e430375f20a8198"}, - {file = "ruff-0.4.9.tar.gz", hash = "sha256:f1cb0828ac9533ba0135d148d214e284711ede33640465e706772645483427e3"}, + {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"}, + {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"}, + {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"}, + {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"}, + {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"}, + {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"}, + {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"}, + {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, ] [[package]] @@ -1673,15 +1673,15 @@ mpmath = ">=1.1.0,<1.4.0" [[package]] name = "tbb" -version = "2021.12.0" +version = "2021.13.0" description = "IntelĀ® oneAPI Threading Building Blocks (oneTBB)" optional = true python-versions = "*" files = [ - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, - {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, - {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:a2567725329639519d46d92a2634cf61e76601dac2f777a05686fea546c4fe4f"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aaf667e92849adb012b8874d6393282afc318aca4407fc62f912ee30a22da46a"}, + {file = "tbb-2021.13.0-py3-none-win32.whl", hash = "sha256:6669d26703e9943f6164c6407bd4a237a45007e79b8d3832fe6999576eaaa9ef"}, + {file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"}, ] [[package]] From 2c1db05637196f85c534b2f3bb3508ccb48f006e Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 15:11:33 -0700 Subject: [PATCH 151/158] wip --- muutils/validate_type.py | 63 ++++++++++++++++--- .../unit/validate_type/test_validate_type.py | 6 +- 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index 4ab85f2a..af66a35b 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -18,10 +18,32 @@ GenericAliasTypes: tuple = tuple([t for t in _GenericAliasTypesList if t is not None]) +class IncorrectTypeException(TypeError): + pass + +class TypeHintNotImplementedError(NotImplementedError): + pass + +class InvalidGenericAliasError(TypeError): + pass + + + def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: - """Validate that a value is of the expected type. use `typeguard` for a more robust solution. + """Validate that a `value` is of the `expected_type`. use `typeguard` for a more robust solution. + + # Parameters + - `value`: the value to check the type of + - `expected_type`: the type to check against. Not all types are supported + + # Returns + - `bool`: `True` if the value is of the expected type, `False` otherwise. + + # Raises + - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented + - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid - https://github.com/agronholm/typeguard + typeguard: https://github.com/agronholm/typeguard """ if expected_type is typing.Any: return True @@ -57,7 +79,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return isinstance(value, list) # incorrect number of args if len(args) != 1: - raise TypeError( + raise InvalidGenericAliasError( f"Too many arguments for list expected 1, got {args = }, {expected_type = }, {value = }, {origin = }", f"{GenericAliasTypes = }", ) @@ -74,7 +96,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return isinstance(value, dict) # incorrect number of args if len(args) != 2: - raise TypeError( + raise InvalidGenericAliasError( f"Expected 2 arguments for dict, expected 2, got {args = }, {expected_type = }, {value = }, {origin = }", f"{GenericAliasTypes = }", ) @@ -95,7 +117,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return isinstance(value, set) # incorrect number of args if len(args) != 1: - raise TypeError( + raise InvalidGenericAliasError( f"Expected 1 argument for Set, got {args = }, {expected_type = }, {value = }, {origin = }", f"{GenericAliasTypes = }", ) @@ -125,7 +147,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: return isinstance(value, type) # incorrect number of args if len(args) != 1: - raise TypeError( + raise InvalidGenericAliasError( f"Expected 1 argument for Type, got {args = }, {expected_type = }, {value = }, {origin = }", f"{GenericAliasTypes = }", ) @@ -137,13 +159,38 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: # TODO: Callables, etc. - raise NotImplementedError( + raise TypeHintNotImplementedError( f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }", f"{GenericAliasTypes = }", ) else: - raise NotImplementedError( + raise TypeHintNotImplementedError( f"Unsupported type hint {expected_type = } for {value = }", f"{GenericAliasTypes = }", ) + + +def validate_type_except(value: typing.Any, expected_type: typing.Any) -> None: + """equvalent to `validate_type` but raises an exception if the type is incorrect + + + # Parameters + - `value`: the value to check the type of + - `expected_type`: the type to check against. Not all types are supported + + # Raises + - `IncorrectTypeException(TypeError)`: if the type is incorrect + - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented + - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid + + # Returns `None` + + """ + + if not validate_type(value, expected_type): + raise IncorrectTypeException( + f"Expected {expected_type = } but got {type(value) = } for {value = }", + f"{typing.get_origin(expected_type) = }", + f"{typing.get_args(expected_type) = }", + ) \ No newline at end of file diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index 8a2c06f4..f5b2ba2f 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -5,7 +5,7 @@ import pytest -from muutils.validate_type import validate_type +from muutils.validate_type import IncorrectTypeException, validate_type, validate_type_except # Tests for basic types and common use cases @@ -445,10 +445,12 @@ def __init__(self, a: int, b: str): self.b: str = b assert validate_type(Parent(1, "a"), Parent) - assert validate_type(Child(1, "a"), Parent) + validate_type_except(Child(1, "a"), Parent) assert validate_type(Child(1, "a"), Child) assert not validate_type(Parent(1, "a"), Child) + with pytest.raises(IncorrectTypeException): + validate_type_except(Parent(1, "a"), Child) def test_validate_type_class(): class Parent: From 21492c6b8beb70864bf42fba77508dcb0b1f27b1 Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 15:36:00 -0700 Subject: [PATCH 152/158] wow im dumb --- muutils/validate_type.py | 109 ++++++++++-------- .../unit/validate_type/test_validate_type.py | 8 +- 2 files changed, 62 insertions(+), 55 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index af66a35b..0acb5ab1 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -2,6 +2,7 @@ import types import typing +import functools # this is also for python <3.10 compatibility _GenericAliasTypeNames: typing.List[str] = [ @@ -28,33 +29,61 @@ class InvalidGenericAliasError(TypeError): pass +def _return_validation_except(return_val: bool, value: typing.Any, expected_type: typing.Any) -> bool: + if return_val: + return True + else: + raise IncorrectTypeException( + f"Expected {expected_type = } for {value = }", + f"{type(value) = }", + f"{type(value).__mro__ = }", + f"{typing.get_origin(expected_type) = }", + f"{typing.get_args(expected_type) = }", + "\ndo --tb=long in pytest to see full trace", + ) + return False + +def _return_validation_bool(return_val: bool, *args, **kwargs) -> bool: + return return_val + -def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: - """Validate that a `value` is of the `expected_type`. use `typeguard` for a more robust solution. +def validate_type(value: typing.Any, expected_type: typing.Any, do_except: bool = False) -> bool: + """Validate that a `value` is of the `expected_type` # Parameters - `value`: the value to check the type of - `expected_type`: the type to check against. Not all types are supported + - `do_except`: if `True`, raise an exception if the type is incorrect (instead of returning `False`) + (default: `False`) # Returns - `bool`: `True` if the value is of the expected type, `False` otherwise. # Raises + - `IncorrectTypeException(TypeError)`: if the type is incorrect and `do_except` is `True` - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid - typeguard: https://github.com/agronholm/typeguard + use `typeguard` for a more robust solution: https://github.com/agronholm/typeguard """ if expected_type is typing.Any: return True + + # set up the return function depending on `do_except` + _return_func: typing.Callable[[bool, typing.Any], bool] = ( + functools.partial(_return_validation_except, value=value, expected_type=expected_type) + if do_except + else _return_validation_bool + ) # base type without args if isinstance(expected_type, type): try: # if you use args on a type like `dict[str, int]`, this will fail - return isinstance(value, expected_type) - except TypeError: - pass + return _return_func(isinstance(value, expected_type)) + except TypeError as e: + if isinstance(e, IncorrectTypeException): + raise e origin: typing.Any = typing.get_origin(expected_type) args: tuple = typing.get_args(expected_type) @@ -68,7 +97,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: if UnionType is None # return False if UnionType is not available else origin is UnionType # return True if UnionType is available ): - return any(validate_type(value, arg) for arg in args) + return _return_func(any(validate_type(value, arg) for arg in args)) # generic alias, more complicated item_type: type @@ -76,7 +105,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: if origin is list: # no args if len(args) == 0: - return isinstance(value, list) + return _return_func(isinstance(value, list)) # incorrect number of args if len(args) != 1: raise InvalidGenericAliasError( @@ -85,7 +114,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: ) # check is list if not isinstance(value, list): - return False + return _return_func(False) # check all items in list are of the correct type item_type = args[0] return all(validate_type(item, item_type) for item in value) @@ -93,7 +122,7 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: if origin is dict: # no args if len(args) == 0: - return isinstance(value, dict) + return _return_func(isinstance(value, dict)) # incorrect number of args if len(args) != 2: raise InvalidGenericAliasError( @@ -102,19 +131,19 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: ) # check is dict if not isinstance(value, dict): - return False + return _return_func(False) # check all items in dict are of the correct type key_type: type = args[0] value_type: type = args[1] - return all( + return _return_func(all( validate_type(key, key_type) and validate_type(val, value_type) for key, val in value.items() - ) + )) if origin is set: # no args if len(args) == 0: - return isinstance(value, set) + return _return_func(isinstance(value, set)) # incorrect number of args if len(args) != 1: raise InvalidGenericAliasError( @@ -123,28 +152,28 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: ) # check is set if not isinstance(value, set): - return False + return _return_func(False) # check all items in set are of the correct type item_type = args[0] - return all(validate_type(item, item_type) for item in value) + return _return_func(all(validate_type(item, item_type) for item in value)) if origin is tuple: # no args if len(args) == 0: - return isinstance(value, tuple) + return _return_func(isinstance(value, tuple)) # check is tuple if not isinstance(value, tuple): - return False + return _return_func(False) # check correct number of items in tuple if len(value) != len(args): - return False + return _return_func(False) # check all items in tuple are of the correct type - return all(validate_type(item, arg) for item, arg in zip(value, args)) + return _return_func(all(validate_type(item, arg) for item, arg in zip(value, args))) if origin is type: # no args if len(args) == 0: - return isinstance(value, type) + return _return_func(isinstance(value, type)) # incorrect number of args if len(args) != 1: raise InvalidGenericAliasError( @@ -152,45 +181,23 @@ def validate_type(value: typing.Any, expected_type: typing.Any) -> bool: f"{GenericAliasTypes = }", ) # check is type - if origin in value.__mro__: - return True + item_type = args[0] + if item_type in value.__mro__: + return _return_func(True) else: - return False + return _return_func(False) # TODO: Callables, etc. raise TypeHintNotImplementedError( f"Unsupported generic alias {expected_type = } for {value = }, {origin = }, {args = }", - f"{GenericAliasTypes = }", + f"{origin = }, {args = }", + f"\n{GenericAliasTypes = }", ) else: raise TypeHintNotImplementedError( f"Unsupported type hint {expected_type = } for {value = }", - f"{GenericAliasTypes = }", + f"{origin = }, {args = }", + f"\n{GenericAliasTypes = }", ) - - -def validate_type_except(value: typing.Any, expected_type: typing.Any) -> None: - """equvalent to `validate_type` but raises an exception if the type is incorrect - - - # Parameters - - `value`: the value to check the type of - - `expected_type`: the type to check against. Not all types are supported - - # Raises - - `IncorrectTypeException(TypeError)`: if the type is incorrect - - `TypeHintNotImplementedError(NotImplementedError)`: if the type hint is not implemented - - `InvalidGenericAliasError(TypeError)`: if the generic alias is invalid - - # Returns `None` - - """ - - if not validate_type(value, expected_type): - raise IncorrectTypeException( - f"Expected {expected_type = } but got {type(value) = } for {value = }", - f"{typing.get_origin(expected_type) = }", - f"{typing.get_args(expected_type) = }", - ) \ No newline at end of file diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index f5b2ba2f..672f74bb 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -5,7 +5,7 @@ import pytest -from muutils.validate_type import IncorrectTypeException, validate_type, validate_type_except +from muutils.validate_type import IncorrectTypeException, validate_type # Tests for basic types and common use cases @@ -445,12 +445,12 @@ def __init__(self, a: int, b: str): self.b: str = b assert validate_type(Parent(1, "a"), Parent) - validate_type_except(Child(1, "a"), Parent) + validate_type(Child(1, "a"), Parent, do_except=True) assert validate_type(Child(1, "a"), Child) assert not validate_type(Parent(1, "a"), Child) with pytest.raises(IncorrectTypeException): - validate_type_except(Parent(1, "a"), Child) + validate_type(Parent(1, "a"), Child, do_except=True) def test_validate_type_class(): class Parent: @@ -465,7 +465,7 @@ def __init__(self, a: int, b: str): assert validate_type(Parent, type) assert validate_type(Child, type) - assert validate_type(Parent, typing.Type[Parent]) + assert validate_type(Parent, typing.Type[Parent], do_except=True) assert validate_type(Child, typing.Type[Child]) assert not validate_type(Parent, typing.Type[Child]) From c2ff8efa408095e42b31b4a7edf6f42edffed180 Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 15:36:24 -0700 Subject: [PATCH 153/158] yes i was literally checking for inheritance on unrelated classes --- tests/unit/validate_type/test_validate_type.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index 672f74bb..2dab7c26 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -439,7 +439,7 @@ def __init__(self, a: int, b: str): self.a: int = a self.b: str = b - class Child: + class Child(Parent): def __init__(self, a: int, b: str): self.a: int = 2 * a self.b: str = b @@ -458,7 +458,7 @@ def __init__(self, a: int, b: str): self.a: int = a self.b: str = b - class Child: + class Child(Parent): def __init__(self, a: int, b: str): self.a: int = 2 * a self.b: str = b @@ -480,7 +480,7 @@ def __init__(self, a: int, b: str): self.a: int = a self.b: str = b - class Child: + class Child(Parent): def __init__(self, a: int, b: str): self.a: int = 2 * a self.b: str = b From 2233c05028829b513c634f326c081fc37f87384c Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 15:37:09 -0700 Subject: [PATCH 154/158] format --- muutils/validate_type.py | 33 +++++++++++++------ .../unit/validate_type/test_validate_type.py | 1 + 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index 0acb5ab1..2a9c3dbe 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -22,14 +22,18 @@ class IncorrectTypeException(TypeError): pass + class TypeHintNotImplementedError(NotImplementedError): pass + class InvalidGenericAliasError(TypeError): pass -def _return_validation_except(return_val: bool, value: typing.Any, expected_type: typing.Any) -> bool: +def _return_validation_except( + return_val: bool, value: typing.Any, expected_type: typing.Any +) -> bool: if return_val: return True else: @@ -43,11 +47,14 @@ def _return_validation_except(return_val: bool, value: typing.Any, expected_type ) return False + def _return_validation_bool(return_val: bool, *args, **kwargs) -> bool: return return_val -def validate_type(value: typing.Any, expected_type: typing.Any, do_except: bool = False) -> bool: +def validate_type( + value: typing.Any, expected_type: typing.Any, do_except: bool = False +) -> bool: """Validate that a `value` is of the `expected_type` # Parameters @@ -68,11 +75,13 @@ def validate_type(value: typing.Any, expected_type: typing.Any, do_except: bool """ if expected_type is typing.Any: return True - + # set up the return function depending on `do_except` _return_func: typing.Callable[[bool, typing.Any], bool] = ( - functools.partial(_return_validation_except, value=value, expected_type=expected_type) - if do_except + functools.partial( + _return_validation_except, value=value, expected_type=expected_type + ) + if do_except else _return_validation_bool ) @@ -135,10 +144,12 @@ def validate_type(value: typing.Any, expected_type: typing.Any, do_except: bool # check all items in dict are of the correct type key_type: type = args[0] value_type: type = args[1] - return _return_func(all( - validate_type(key, key_type) and validate_type(val, value_type) - for key, val in value.items() - )) + return _return_func( + all( + validate_type(key, key_type) and validate_type(val, value_type) + for key, val in value.items() + ) + ) if origin is set: # no args @@ -168,7 +179,9 @@ def validate_type(value: typing.Any, expected_type: typing.Any, do_except: bool if len(value) != len(args): return _return_func(False) # check all items in tuple are of the correct type - return _return_func(all(validate_type(item, arg) for item, arg in zip(value, args))) + return _return_func( + all(validate_type(item, arg) for item, arg in zip(value, args)) + ) if origin is type: # no args diff --git a/tests/unit/validate_type/test_validate_type.py b/tests/unit/validate_type/test_validate_type.py index 2dab7c26..0bb8de63 100644 --- a/tests/unit/validate_type/test_validate_type.py +++ b/tests/unit/validate_type/test_validate_type.py @@ -452,6 +452,7 @@ def __init__(self, a: int, b: str): with pytest.raises(IncorrectTypeException): validate_type(Parent(1, "a"), Child, do_except=True) + def test_validate_type_class(): class Parent: def __init__(self, a: int, b: str): From 73df7fa186a9efd8098bb897536bcbb4b9fcbe99 Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 15:41:06 -0700 Subject: [PATCH 155/158] typing fix --- muutils/validate_type.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/muutils/validate_type.py b/muutils/validate_type.py index 2a9c3dbe..cb5a29fd 100644 --- a/muutils/validate_type.py +++ b/muutils/validate_type.py @@ -48,7 +48,7 @@ def _return_validation_except( return False -def _return_validation_bool(return_val: bool, *args, **kwargs) -> bool: +def _return_validation_bool(return_val: bool) -> bool: return return_val @@ -77,8 +77,9 @@ def validate_type( return True # set up the return function depending on `do_except` - _return_func: typing.Callable[[bool, typing.Any], bool] = ( - functools.partial( + _return_func: typing.Callable[[bool], bool] = ( + # functools.partial doesn't hint the function signature + functools.partial( # type: ignore[assignment] _return_validation_except, value=value, expected_type=expected_type ) if do_except From 3735580ae27e965bba18343471b89e3454bca4cd Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 15:57:13 -0700 Subject: [PATCH 156/158] add deserialize_fn as clearer alternative to loading_fn resolves #40 --- .../json_serialize/serializable_dataclass.py | 6 ++- muutils/json_serialize/serializable_field.py | 38 +++++++++++++++++-- .../test_serializable_dataclass.py | 20 ++++++++++ 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/muutils/json_serialize/serializable_dataclass.py b/muutils/json_serialize/serializable_dataclass.py index 95178f42..47e35253 100644 --- a/muutils/json_serialize/serializable_dataclass.py +++ b/muutils/json_serialize/serializable_dataclass.py @@ -608,7 +608,11 @@ def load(cls, data: dict[str, Any] | T) -> Type[T]: # get the type hint for the field field_type_hint: Any = cls_type_hints.get(field.name, None) - if field.loading_fn: + # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set + if field.deserialize_fn: + # if it has a deserialization function, use that + value = field.deserialize_fn(value) + elif field.loading_fn: # if it has a loading function, use that value = field.loading_fn(data) elif ( diff --git a/muutils/json_serialize/serializable_field.py b/muutils/json_serialize/serializable_field.py index 54896d6f..814188ef 100644 --- a/muutils/json_serialize/serializable_field.py +++ b/muutils/json_serialize/serializable_field.py @@ -29,6 +29,7 @@ class SerializableField(dataclasses.Field): "serialize", "serialization_fn", "loading_fn", + "deserialize_fn", # new alternative to loading_fn "assert_type", "custom_typecheck_fn", ) @@ -49,6 +50,7 @@ def __init__( serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, loading_fn: Optional[Callable[[Any], Any]] = None, + deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, ): @@ -85,7 +87,16 @@ def __init__( # now init the new fields self.serialize: bool = serialize self.serialization_fn: Optional[Callable[[Any], Any]] = serialization_fn + + if loading_fn is not None and deserialize_fn is not None: + raise ValueError( + "Cannot pass both loading_fn and deserialize_fn, pass only one. ", + "`loading_fn` is the older interface and takes the dict of the class, ", + "`deserialize_fn` is the new interface and takes only the field's value.", + ) self.loading_fn: Optional[Callable[[Any], Any]] = loading_fn + self.deserialize_fn: Optional[Callable[[Any], Any]] = deserialize_fn + self.assert_type: bool = assert_type self.custom_typecheck_fn: Optional[Callable[[type], bool]] = custom_typecheck_fn @@ -101,9 +112,10 @@ def from_Field(cls, field: dataclasses.Field) -> "SerializableField": compare=field.compare, metadata=field.metadata, kw_only=getattr(field, "kw_only", dataclasses.MISSING), # for python <3.9 - serialize=field.repr, + serialize=field.repr, # serialize if it's going to be repr'd serialization_fn=None, loading_fn=None, + deserialize_fn=None, ) @@ -121,10 +133,12 @@ def serializable_field(*args, **kwargs): # -> SerializableField: compare: bool = True, metadata: types.MappingProxyType | None = None, kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING, + # ---------------------------------------------------------------------- # new in `SerializableField`, not in `dataclasses.Field` serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, loading_fn: Optional[Callable[[Any], Any]] = None, + deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, ``` @@ -133,16 +147,32 @@ def serializable_field(*args, **kwargs): # -> SerializableField: - `serialize`: whether to serialize this field when serializing the class' - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize` - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is. + - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised. # Gotchas: - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write: + ```python class MyClass: - my_field: int = serializable_field(loading_fn=lambda x["my_field"]: x) + my_field: int = serializable_field( + serialization_fn=lambda x: str(x), + loading_fn=lambda x["my_field"]: int(x) + ) ``` - issue to add a different way of doing this: https://github.com/mivanit/muutils/issues/40 - note that if not using ZANJ, and you have a class inside a container, you MUST provide + using `deserialize_fn` instead: + + ```python + class MyClass: + my_field: int = serializable_field( + serialization_fn=lambda x: str(x), + deserialize_fn=lambda x: int(x) + ) + ``` + + In the above code, `my_field` is an int but will be serialized as a string. + + note that if not using ZANJ, and you have a class inside a container, you MUST provide `serialization_fn` and `loading_fn` to serialize and load the container. ZANJ will automatically do this for you. """ diff --git a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py index 18f7254e..7479a82a 100644 --- a/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py +++ b/tests/unit/json_serialize/serializable_dataclass/test_serializable_dataclass.py @@ -380,3 +380,23 @@ def test_nested_custom(recwarn): # this will send some warnings but whatever assert serialized == expected_ser loaded = nested_custom.load(serialized) assert loaded == instance + + +def test_deserialize_fn(): + @serializable_dataclass + class DeserializeFn(SerializableDataclass): + data: int = serializable_field( + serialization_fn=lambda x: str(x), + deserialize_fn=lambda x: int(x), + ) + + instance = DeserializeFn(data=5) + serialized = instance.serialize() + assert serialized == { + "data": "5", + "__format__": "DeserializeFn(SerializableDataclass)", + } + + loaded = DeserializeFn.load(serialized) + assert loaded == instance + assert loaded.data == 5 From 4dcda67d2afdec2e80954006919aa56f9c79fbba Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 16:11:08 -0700 Subject: [PATCH 157/158] update coverage --- docs/coverage/coverage.svg | 4 +- docs/coverage/coverage.txt | 89 ++++++++++++++++++++------------------ 2 files changed, 50 insertions(+), 43 deletions(-) diff --git a/docs/coverage/coverage.svg b/docs/coverage/coverage.svg index 90299371..b750dd9c 100644 --- a/docs/coverage/coverage.svg +++ b/docs/coverage/coverage.svg @@ -15,7 +15,7 @@ coverage coverage - 76% - 76% + 77% + 77% diff --git a/docs/coverage/coverage.txt b/docs/coverage/coverage.txt index b1b1c67f..61f3bc10 100644 --- a/docs/coverage/coverage.txt +++ b/docs/coverage/coverage.txt @@ -1,55 +1,62 @@ Name Stmts Miss Cover Missing --------------------------------------------------------------------------------------------------------------- -muutils\__init__.py 0 0 100% -muutils\dictmagic.py 162 23 86% 14-19, 22-25, 177, 285, 445-449, 453, 486-498 -muutils\group_equiv.py 28 0 100% -muutils\json_serialize\__init__.py 5 5 0% 1-15 -muutils\json_serialize\array.py 80 30 62% 1-14, 18-24, 35, 74, 102-105, 114, 118, 121, 125, 129-132, 136, 143, 155, 174-179, 184, 187 -muutils\json_serialize\json_serialize.py 64 45 30% 1-80, 124, 208-234, 244, 256, 271-304 -muutils\json_serialize\serializable_dataclass.py 187 71 62% 1-35, 69, 82-102, 129, 147, 160-165, 231-248, 257-267, 288-325, 329, 339-340, 344, 353-357, 361, 370, 381-389, 417, 443, 459-460, 483, 510, 516 -muutils\json_serialize\util.py 76 48 37% 1-58, 62, 73, 77, 93, 102, 107-111, 122, 125-126 -muutils\jsonlines.py 31 31 0% 1-73 -muutils\kappa.py 14 0 100% +muutils\__init__.py 1 1 0% 1 +muutils\dictmagic.py 158 21 87% 28-33, 36-39, 297, 457-461, 465, 498-510 +muutils\errormode.py 32 13 59% 1-13, 30-33, 49 +muutils\group_equiv.py 29 0 100% +muutils\json_serialize\__init__.py 6 6 0% 1-17 +muutils\json_serialize\array.py 92 33 64% 1-24, 28-34, 45, 85, 119-122, 131, 135, 139, 142, 146, 150-153, 157, 164, 176, 203-208, 213, 216 +muutils\json_serialize\json_serialize.py 63 45 29% 1-81, 125, 209-235, 245, 257, 272-305 +muutils\json_serialize\serializable_dataclass.py 216 87 60% 1-36, 50-55, 75-83, 108, 114, 144, 148-171, 190-192, 196, 208, 220-242, 251-257, 286, 299-300, 306, 317-321, 326, 338, 355-366, 372-384, 394, 485, 493-497, 526, 550-551, 585, 627 +muutils\json_serialize\serializable_field.py 32 17 47% 1-37, 73, 79-82, 92, 103-123 +muutils\json_serialize\util.py 109 63 42% 1-43, 45, 49-61, 65, 76, 80, 96, 105, 110-114, 125, 128-133, 151, 164-169, 235-252 +muutils\jsonlines.py 32 32 0% 1-75 +muutils\kappa.py 15 0 100% muutils\logger\__init__.py 5 0 100% muutils\logger\exception_context.py 12 6 50% 24, 27, 30-43 -muutils\logger\headerfuncs.py 18 1 94% 53 +muutils\logger\headerfuncs.py 19 1 95% 55 muutils\logger\log_util.py 32 32 0% 1-80 -muutils\logger\logger.py 97 25 74% 26-34, 85, 88, 133, 153-154, 192, 225, 235, 255-259, 275-278, 293, 297, 304 -muutils\logger\loggingstream.py 39 12 69% 41-74, 79, 89-90 -muutils\logger\simplelogger.py 40 19 52% 14, 18, 22, 26, 53-63, 67-79 -muutils\logger\timing.py 39 19 51% 25-28, 41-46, 50-52, 65-68, 79-84 -muutils\misc.py 164 12 93% 155, 186, 227-229, 288, 297, 300, 320-321, 334, 373-374 -muutils\mlutils.py 66 39 41% 1-26, 29, 35-49, 54, 56, 65-73, 96, 104, 128-131, 141-142, 152-153 +muutils\logger\logger.py 98 25 74% 28-36, 87, 90, 135, 155-156, 194, 227, 237, 257-261, 277-280, 295, 299, 306 +muutils\logger\loggingstream.py 40 12 70% 43-76, 81, 91-92 +muutils\logger\simplelogger.py 41 19 54% 16, 20, 24, 28, 55-65, 69-81 +muutils\logger\timing.py 39 18 54% 27-30, 43-48, 52-54, 67-70, 81-87 +muutils\misc.py 172 11 94% 210, 241, 284, 342, 351, 354, 374-375, 388, 427-428 +muutils\mlutils.py 72 43 40% 1-28, 31, 37-51, 56, 58, 67-75, 98, 106, 128-131, 142-147, 151-152, 162-163 muutils\nbutils\__init__.py 2 2 0% 1-3 -muutils\nbutils\configure_notebook.py 132 79 40% 1-55, 73-82, 96, 106-107, 112, 115-116, 136-139, 144, 150-159, 166-171, 180, 219-230, 236-241, 264-271, 274 -muutils\nbutils\convert_ipynb_to_script.py 118 41 65% 63, 78, 91, 105-139, 228-230, 236, 263, 287-289, 296-349 +muutils\nbutils\configure_notebook.py 133 80 40% 1-57, 75-84, 106, 116-117, 122, 125-126, 146-149, 154, 160-169, 176-181, 190, 229-240, 246-251, 274-281, 284 +muutils\nbutils\convert_ipynb_to_script.py 119 41 66% 65, 80, 93, 107-141, 230-232, 238, 265, 289-291, 298-351 muutils\nbutils\mermaid.py 11 11 0% 1-18 muutils\nbutils\print_tex.py 10 10 0% 1-19 -muutils\nbutils\run_notebook_tests.py 58 20 66% 29, 31, 35, 39, 45, 53, 80-82, 86-90, 97-114 -muutils\statcounter.py 87 32 63% 24-35, 50, 70, 98, 110, 120, 136-166, 174, 183, 186, 190-195, 204 -muutils\sysinfo.py 71 18 75% 21, 60-61, 78-111, 145, 168 -muutils\tensor_utils.py 125 19 85% 83, 86, 105, 109, 119, 130, 133-136, 144, 151, 165-173 -tests\unit\json_serialize\serializable_dataclass\test_helpers.py 42 0 100% -tests\unit\json_serialize\serializable_dataclass\test_sdc_defaults.py 31 0 100% -tests\unit\json_serialize\serializable_dataclass\test_sdc_properties_nested.py 26 0 100% -tests\unit\json_serialize\serializable_dataclass\test_serializable_dataclass.py 190 0 100% +muutils\nbutils\run_notebook_tests.py 59 20 66% 29, 31, 35, 39, 45, 53, 79-81, 85-89, 96-113 +muutils\statcounter.py 89 32 64% 25-36, 51, 75, 103, 115, 125, 141-171, 179, 188, 191, 195-200, 209 +muutils\sysinfo.py 78 25 68% 20-23, 64-65, 82-115, 156-165, 177, 195-197 +muutils\tensor_utils.py 124 18 85% 86, 89, 108, 112, 129, 132-135, 143, 150, 164-172 +muutils\validate_type.py 82 19 77% 1-34, 51, 55, 190, 193, 213 +tests\unit\json_serialize\serializable_dataclass\test_helpers.py 43 0 100% +tests\unit\json_serialize\serializable_dataclass\test_sdc_defaults.py 32 0 100% +tests\unit\json_serialize\serializable_dataclass\test_sdc_properties_nested.py 44 1 98% 44 +tests\unit\json_serialize\serializable_dataclass\test_serializable_dataclass.py 204 0 100% tests\unit\json_serialize\test_array.py 40 0 100% tests\unit\json_serialize\test_util.py 49 2 96% 66, 73 -tests\unit\logger\test_logger.py 10 0 100% -tests\unit\logger\test_timer_context.py 9 0 100% -tests\unit\misc\test_freeze.py 120 0 100% -tests\unit\misc\test_misc.py 43 0 100% -tests\unit\misc\test_numerical_conversions.py 42 0 100% -tests\unit\nbutils\test_configure_notebook.py 61 0 100% -tests\unit\nbutils\test_conversion.py 26 0 100% +tests\unit\logger\test_logger.py 11 0 100% +tests\unit\logger\test_timer_context.py 11 0 100% +tests\unit\misc\test_freeze.py 121 0 100% +tests\unit\misc\test_misc.py 74 0 100% +tests\unit\misc\test_numerical_conversions.py 43 0 100% +tests\unit\nbutils\test_configure_notebook.py 70 0 100% +tests\unit\nbutils\test_conversion.py 27 0 100% tests\unit\test_chunks.py 31 0 100% -tests\unit\test_dictmagic.py 129 0 100% -tests\unit\test_group_equiv.py 12 0 100% +tests\unit\test_dictmagic.py 130 0 100% +tests\unit\test_errormode.py 69 1 99% 122 +tests\unit\test_group_equiv.py 13 0 100% tests\unit\test_import_torch.py 4 0 100% tests\unit\test_kappa.py 39 0 100% -tests\unit\test_mlutils.py 35 3 91% 31, 35, 43 -tests\unit\test_statcounter.py 13 0 100% -tests\unit\test_sysinfo.py 4 0 100% -tests\unit\test_tensor_utils.py 48 0 100% +tests\unit\test_mlutils.py 43 6 86% 35, 39, 47, 50, 57-58 +tests\unit\test_statcounter.py 14 0 100% +tests\unit\test_sysinfo.py 6 0 100% +tests\unit\test_tensor_utils.py 51 0 100% +tests\unit\validate_type\test_validate_type.py 206 45 78% 49-50, 74-75, 101-102, 124-125, 146-147, 176-177, 202-203, 235-236, 253, 272-273, 334-335, 359-360, 430-431, 459-460, 464-465, 479-497 +tests\unit\validate_type\test_validate_type_GENERATED.py 206 45 78% 50-51, 75-76, 102-103, 125-126, 147-148, 177-178, 203-204, 236-237, 254, 273-274, 335-336, 360-361, 431-432, 460-461, 465-466, 480-498 +tests\unit\validate_type\test_validate_type_special.py 15 3 80% 34-35, 57 --------------------------------------------------------------------------------------------------------------- -TOTAL 2777 655 76% +TOTAL 3618 846 77% From 3dac41f9fa0cb70f93277a91f8f051f5ee97ca70 Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 21 Jun 2024 16:28:57 -0700 Subject: [PATCH 158/158] update deps --- poetry.lock | 10 +++++----- pyproject.toml | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index dd7df846..d3e33158 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1888,17 +1888,17 @@ files = [ [[package]] name = "zanj" -version = "0.2.2" +version = "0.2.3" description = "save and load complex objects to disk without pickling" optional = true python-versions = "<4.0,>=3.10" files = [ - {file = "zanj-0.2.2-py3-none-any.whl", hash = "sha256:91cce89bf8e7041e8acca3071b935899edad0ff776bf4e7bf0d98cfa6cb28f1d"}, - {file = "zanj-0.2.2.tar.gz", hash = "sha256:71c6110f9b9d1a0fe04c011156b8e2c96f6fac1a6f9dc690b6c380862fd5fb8d"}, + {file = "zanj-0.2.3-py3-none-any.whl", hash = "sha256:4992eba4b7b48264da9a243922d69f9689c1d6b11e7915b49e84aeda61b43297"}, + {file = "zanj-0.2.3.tar.gz", hash = "sha256:d0d45ddda07911723d2c2e57db84b1613b1498f66eb418af197d192c89fb87e9"}, ] [package.dependencies] -muutils = {version = ">=0.5.1,<0.6.0", extras = ["array"]} +muutils = {version = ">=0.5.1,<0.7.0", extras = ["array"]} [package.extras] pandas = ["pandas (>=1.5.3)"] @@ -1927,4 +1927,4 @@ zanj = ["zanj"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "ba5985c876532a581082f5dbc55e3ca2237cf82f08044181d55b19a001b00371" +content-hash = "5d222b1676114ac943bd9abb6048a1eed168c24d32be9ca026bde19f10371702" diff --git a/pyproject.toml b/pyproject.toml index e293d6f4..19abb2d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "muutils" -version = "0.5.13" +version = "0.6.0" description = "miscellaneous python utilities" license = "GPL-3.0-only" authors = ["mivanit "] @@ -31,7 +31,7 @@ jaxtyping = { version = "^0.2.12", optional = true } # [notebook] ipython = { version = "^8.20.0", optional = true, python = "^3.10" } # [zanj] -zanj = { version = "^0.2.2", optional = true, python = "^3.10" } +zanj = { version = "^0.2.3", optional = true, python = "^3.10" } [tool.poetry.group.dev.dependencies] # typing