diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e9caef5a..414f1910 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -38,8 +38,8 @@ repos:
      language: python
      types_or: [python, pyi]
      exclude: ^auto_tutorials_source/
-  - id: pytest-check
-    name: pytest-check
-    entry: pytest
-    language: system
-    pass_filenames: false
+  # - id: pytest-check
+  #   name: pytest-check
+  #   entry: pytest
+  #   language: system
+  #   pass_filenames: false
diff --git a/README.md b/README.md
index 11c5738d..4ffc9274 100644
--- a/README.md
+++ b/README.md
@@ -51,20 +51,26 @@ A quickstart is available at [torch-uncertainty.github.io/quickstart](https://to
 To date, the following deep learning baselines have been implemented:
 
 - Deep Ensembles
-- MC-Dropout
+- MC-Dropout - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html)
 - BatchEnsemble
 - Masksembles
 - MIMO
-- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873))
-- Bayesian Neural Networks :construction: Work in progress :construction:
+- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html)
+- Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html)
 - Regression with Beta Gaussian NLL Loss
-- Deep Evidential Classification & Regression
+- Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html)
+
+### Augmentation methods
+
+The following data augmentation methods have been implemented:
+
+- Mixup, MixupIO, RegMixup, WarpingMixup
 
 ### Post-processing methods
 
 To date, the following post-processing methods have been implemented:
 
-- Temperature, Vector, & Matrix scaling
+- Temperature, Vector, & Matrix scaling - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html)
 
 ## Tutorials
diff --git a/docs/source/references.rst b/docs/source/references.rst
index 80432303..456a9627 100644
--- a/docs/source/references.rst
+++ b/docs/source/references.rst
@@ -114,6 +114,39 @@ For Monte-Carlo Dropout, consider citing:
 * Authors: *Yarin Gal and Zoubin Ghahramani*
 * Paper: `ICML 2016 `__.
 
+Data Augmentation Methods
+-------------------------
+
+Mixup
+^^^^^
+
+For Mixup, consider citing:
+
+**mixup: Beyond Empirical Risk Minimization**
+
+* Authors: *Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz*
+* Paper: `ICLR 2018 `__.
+
+MixupIO
+^^^^^^^
+
+For MixupIO, consider citing:
+
+**On the Pitfall of Mixup for Uncertainty Calibration**
+
+* Authors: *Deng-Bao Wang, Lanqing Li, Peilin Zhao, Pheng-Ann Heng, and Min-Ling Zhang*
+* Paper: `CVPR 2023 `__.
+
+Warping Mixup
+^^^^^^^^^^^^^
+
+For Warping Mixup, consider citing:
+
+**Tailoring Mixup to Data using Kernel Warping functions**
+
+* Authors: *Quentin Bouniot, Pavlo Mozharovskyi, and Florence d'Alché-Buc*
+* Paper: `ArXiv 2023 `__.
+
 Post-Processing Methods
 -----------------------
 
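For readers skimming these new references: every method in the Mixup family builds on the same core step, a convex combination of randomly paired samples with a Beta-distributed coefficient. Below is a minimal, generic sketch of that step as described in the cited ICLR 2018 paper; the `mixup_batch` helper and its signature are illustrative assumptions, not the torch-uncertainty API.

```python
import torch


def mixup_batch(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """Generic Mixup step (Zhang et al., ICLR 2018).

    Mixes each sample with a randomly chosen partner from the same batch;
    the two labels are interpolated in the loss with the same coefficient.
    """
    # Interpolation coefficient lam ~ Beta(alpha, alpha).
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    # Random pairing within the batch.
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1.0 - lam) * x[index]
    return mixed_x, y, y[index], lam


# Typical use in a training step, with criterion = nn.CrossEntropyLoss():
#   mixed_x, y_a, y_b, lam = mixup_batch(inputs, targets, alpha=1.0)
#   logits = model(mixed_x)
#   loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
```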
diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py
index fe555e95..1540c0bf 100644
--- a/experiments/classification/cifar10/resnet.py
+++ b/experiments/classification/cifar10/resnet.py
@@ -6,6 +6,7 @@
 from torch_uncertainty.baselines import ResNet
 from torch_uncertainty.datamodules import CIFAR10DataModule
 from torch_uncertainty.optimization_procedures import get_procedure
+from torch_uncertainty.utils import csv_writer
 
 if __name__ == "__main__":
     args = init_args(ResNet, CIFAR10DataModule)
@@ -14,22 +15,59 @@
     else:
         root = Path(args.root)
 
-    net_name = f"{args.version}-resnet{args.arch}-cifar10"
+    if args.exp_name == "":
+        args.exp_name = f"{args.version}-resnet{args.arch}-cifar10"
 
     # datamodule
     args.root = str(root / "data")
    dm = CIFAR10DataModule(**vars(args))
 
-    # model
-    model = ResNet(
-        num_classes=dm.num_classes,
-        in_channels=dm.num_channels,
-        loss=nn.CrossEntropyLoss,
-        optimization_procedure=get_procedure(
-            f"resnet{args.arch}", "cifar10", args.version
-        ),
-        style="cifar",
-        **vars(args),
-    )
-
-    cli_main(model, dm, root, net_name, args)
+    if args.opt_temp_scaling:
+        calibration_set = dm.get_test_set
+    elif args.val_temp_scaling:
+        calibration_set = dm.get_val_set
+    else:
+        calibration_set = None
+
+    if args.use_cv:
+        list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over)
+        list_model = []
+        for i in range(len(list_dm)):
+            list_model.append(
+                ResNet(
+                    num_classes=list_dm[i].dm.num_classes,
+                    in_channels=list_dm[i].dm.num_channels,
+                    loss=nn.CrossEntropyLoss,
+                    optimization_procedure=get_procedure(
+                        f"resnet{args.arch}", "cifar10", args.version
+                    ),
+                    style="cifar",
+                    calibration_set=calibration_set,
+                    **vars(args),
+                )
+            )
+
+        results = cli_main(
+            list_model, list_dm, args.exp_dir, args.exp_name, args
+        )
+    else:
+        # model
+        model = ResNet(
+            num_classes=dm.num_classes,
+            in_channels=dm.num_channels,
+            loss=nn.CrossEntropyLoss,
+            optimization_procedure=get_procedure(
+                f"resnet{args.arch}", "cifar10", args.version
+            ),
+            style="cifar",
+            calibration_set=calibration_set,
+            **vars(args),
+        )
+
+        results = cli_main(model, dm, args.exp_dir, args.exp_name, args)
+
+    for dict_result in results:
+        csv_writer(
+            Path(args.exp_dir) / Path(args.exp_name) / "results.csv",
+            dict_result,
+        )
diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py
index db740472..2503b3da 100644
--- a/experiments/classification/tiny-imagenet/resnet.py
+++ b/experiments/classification/tiny-imagenet/resnet.py
@@ -5,6 +5,8 @@
 from torch_uncertainty import cli_main, init_args
 from torch_uncertainty.baselines import ResNet
 from torch_uncertainty.datamodules import TinyImageNetDataModule
+from torch_uncertainty.optimization_procedures import get_procedure
+from torch_uncertainty.utils import csv_writer
 
 
 def optim_tiny(model: nn.Module) -> dict:
@@ -26,20 +28,60 @@
     else:
         root = Path(args.root)
 
-    net_name = f"{args.version}-resnet{args.arch}-tiny-imagenet"
+    # net_name = f"{args.version}-resnet{args.arch}-tiny-imagenet"
+    if args.exp_name == "":
+        args.exp_name = f"{args.version}-resnet{args.arch}-tiny-imagenet"
 
     # datamodule
     args.root = str(root / "data")
     dm = TinyImageNetDataModule(**vars(args))
 
-    # model
-    model = ResNet(
-        num_classes=dm.num_classes,
-        in_channels=dm.num_channels,
-        loss=nn.CrossEntropyLoss,
-        optimization_procedure=optim_tiny,
-        style="cifar",
-        **vars(args),
-    )
+    if args.opt_temp_scaling:
+        calibration_set = dm.get_test_set
+    elif args.val_temp_scaling:
+        calibration_set = dm.get_val_set
+    else:
+        calibration_set = None
+
+    if args.use_cv:
+        list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over)
+        list_model = []
+        for i in range(len(list_dm)):
+            list_model.append(
+                ResNet(
+                    num_classes=list_dm[i].dm.num_classes,
+                    in_channels=list_dm[i].dm.num_channels,
+                    loss=nn.CrossEntropyLoss,
+                    optimization_procedure=get_procedure(
+                        f"resnet{args.arch}", "tiny-imagenet", args.version
+                    ),
+                    style="cifar",
+                    calibration_set=calibration_set,
+                    **vars(args),
+                )
+            )
+
+        results = cli_main(
+            list_model, list_dm, args.exp_dir, args.exp_name, args
+        )
+    else:
+        # model
+        model = ResNet(
+            num_classes=dm.num_classes,
+            in_channels=dm.num_channels,
+            loss=nn.CrossEntropyLoss,
+            optimization_procedure=get_procedure(
+                f"resnet{args.arch}", "tiny-imagenet", args.version
+            ),
+            calibration_set=calibration_set,
+            style="cifar",
+            **vars(args),
+        )
+
+        results = cli_main(model, dm, args.exp_dir, args.exp_name, args)
 
-    cli_main(model, dm, root, net_name, args)
+    for dict_result in results:
+        csv_writer(
+            Path(args.exp_dir) / Path(args.exp_name) / "results.csv",
+            dict_result,
+        )
diff --git a/poetry.lock b/poetry.lock
index 39406cf2..fe9989d4 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -240,101 +240,101 @@ files = [
 
 [[package]]
 name = "charset-normalizer"
-version = "3.3.0"
+version = "3.3.1"
 description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
 optional = false
 python-versions = ">=3.7.0"
 files = [
-    {file = "charset-normalizer-3.3.0.tar.gz", hash = "sha256:63563193aec44bce707e0c5ca64ff69fa72ed7cf34ce6e11d5127555756fd2f6"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:effe5406c9bd748a871dbcaf3ac69167c38d72db8c9baf3ff954c344f31c4cbe"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4162918ef3098851fcd8a628bf9b6a98d10c380725df9e04caf5ca6dd48c847a"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0570d21da019941634a531444364f2482e8db0b3425fcd5ac0c36565a64142c8"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5707a746c6083a3a74b46b3a631d78d129edab06195a92a8ece755aac25a3f3d"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:278c296c6f96fa686d74eb449ea1697f3c03dc28b75f873b65b5201806346a69"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a4b71f4d1765639372a3b32d2638197f5cd5221b19531f9245fcc9ee62d38f56"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5969baeaea61c97efa706b9b107dcba02784b1601c74ac84f2a532ea079403e"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3f93dab657839dfa61025056606600a11d0b696d79386f974e459a3fbc568ec"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:db756e48f9c5c607b5e33dd36b1d5872d0422e960145b08ab0ec7fd420e9d649"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:232ac332403e37e4a03d209a3f92ed9071f7d3dbda70e2a5e9cff1c4ba9f0678"},
-    {file = "charset_normalizer-3.3.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = 
"sha256:e5c1502d4ace69a179305abb3f0bb6141cbe4714bc9b31d427329a95acfc8bdd"}, - {file = "charset_normalizer-3.3.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:2502dd2a736c879c0f0d3e2161e74d9907231e25d35794584b1ca5284e43f596"}, - {file = "charset_normalizer-3.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23e8565ab7ff33218530bc817922fae827420f143479b753104ab801145b1d5b"}, - {file = "charset_normalizer-3.3.0-cp310-cp310-win32.whl", hash = "sha256:1872d01ac8c618a8da634e232f24793883d6e456a66593135aeafe3784b0848d"}, - {file = "charset_normalizer-3.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:557b21a44ceac6c6b9773bc65aa1b4cc3e248a5ad2f5b914b91579a32e22204d"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d7eff0f27edc5afa9e405f7165f85a6d782d308f3b6b9d96016c010597958e63"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6a685067d05e46641d5d1623d7c7fdf15a357546cbb2f71b0ebde91b175ffc3e"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0d3d5b7db9ed8a2b11a774db2bbea7ba1884430a205dbd54a32d61d7c2a190fa"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2935ffc78db9645cb2086c2f8f4cfd23d9b73cc0dc80334bc30aac6f03f68f8c"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fe359b2e3a7729010060fbca442ca225280c16e923b37db0e955ac2a2b72a05"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:380c4bde80bce25c6e4f77b19386f5ec9db230df9f2f2ac1e5ad7af2caa70459"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0d1e3732768fecb052d90d62b220af62ead5748ac51ef61e7b32c266cac9293"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1b2919306936ac6efb3aed1fbf81039f7087ddadb3160882a57ee2ff74fd2382"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f8888e31e3a85943743f8fc15e71536bda1c81d5aa36d014a3c0c44481d7db6e"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:82eb849f085624f6a607538ee7b83a6d8126df6d2f7d3b319cb837b289123078"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7b8b8bf1189b3ba9b8de5c8db4d541b406611a71a955bbbd7385bbc45fcb786c"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5adf257bd58c1b8632046bbe43ee38c04e1038e9d37de9c57a94d6bd6ce5da34"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c350354efb159b8767a6244c166f66e67506e06c8924ed74669b2c70bc8735b1"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-win32.whl", hash = "sha256:02af06682e3590ab952599fbadac535ede5d60d78848e555aa58d0c0abbde786"}, - {file = "charset_normalizer-3.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:86d1f65ac145e2c9ed71d8ffb1905e9bba3a91ae29ba55b4c46ae6fc31d7c0d4"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3b447982ad46348c02cb90d230b75ac34e9886273df3a93eec0539308a6296d7"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:abf0d9f45ea5fb95051c8bfe43cb40cda383772f7e5023a83cc481ca2604d74e"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-macosx_11_0_arm64.whl", 
hash = "sha256:b09719a17a2301178fac4470d54b1680b18a5048b481cb8890e1ef820cb80455"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3d9b48ee6e3967b7901c052b670c7dda6deb812c309439adaffdec55c6d7b78"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:edfe077ab09442d4ef3c52cb1f9dab89bff02f4524afc0acf2d46be17dc479f5"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3debd1150027933210c2fc321527c2299118aa929c2f5a0a80ab6953e3bd1908"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86f63face3a527284f7bb8a9d4f78988e3c06823f7bea2bd6f0e0e9298ca0403"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:24817cb02cbef7cd499f7c9a2735286b4782bd47a5b3516a0e84c50eab44b98e"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c71f16da1ed8949774ef79f4a0260d28b83b3a50c6576f8f4f0288d109777989"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9cf3126b85822c4e53aa28c7ec9869b924d6fcfb76e77a45c44b83d91afd74f9"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:b3b2316b25644b23b54a6f6401074cebcecd1244c0b8e80111c9a3f1c8e83d65"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:03680bb39035fbcffe828eae9c3f8afc0428c91d38e7d61aa992ef7a59fb120e"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4cc152c5dd831641e995764f9f0b6589519f6f5123258ccaca8c6d34572fefa8"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-win32.whl", hash = "sha256:b8f3307af845803fb0b060ab76cf6dd3a13adc15b6b451f54281d25911eb92df"}, - {file = "charset_normalizer-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:8eaf82f0eccd1505cf39a45a6bd0a8cf1c70dcfc30dba338207a969d91b965c0"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:dc45229747b67ffc441b3de2f3ae5e62877a282ea828a5bdb67883c4ee4a8810"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f4a0033ce9a76e391542c182f0d48d084855b5fcba5010f707c8e8c34663d77"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ada214c6fa40f8d800e575de6b91a40d0548139e5dc457d2ebb61470abf50186"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b1121de0e9d6e6ca08289583d7491e7fcb18a439305b34a30b20d8215922d43c"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1063da2c85b95f2d1a430f1c33b55c9c17ffaf5e612e10aeaad641c55a9e2b9d"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70f1d09c0d7748b73290b29219e854b3207aea922f839437870d8cc2168e31cc"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:250c9eb0f4600361dd80d46112213dff2286231d92d3e52af1e5a6083d10cad9"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:750b446b2ffce1739e8578576092179160f6d26bd5e23eb1789c4d64d5af7dc7"}, - {file = 
"charset_normalizer-3.3.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:fc52b79d83a3fe3a360902d3f5d79073a993597d48114c29485e9431092905d8"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:588245972aca710b5b68802c8cad9edaa98589b1b42ad2b53accd6910dad3545"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e39c7eb31e3f5b1f88caff88bcff1b7f8334975b46f6ac6e9fc725d829bc35d4"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-win32.whl", hash = "sha256:abecce40dfebbfa6abf8e324e1860092eeca6f7375c8c4e655a8afb61af58f2c"}, - {file = "charset_normalizer-3.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:24a91a981f185721542a0b7c92e9054b7ab4fea0508a795846bc5b0abf8118d4"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:67b8cc9574bb518ec76dc8e705d4c39ae78bb96237cb533edac149352c1f39fe"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ac71b2977fb90c35d41c9453116e283fac47bb9096ad917b8819ca8b943abecd"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3ae38d325b512f63f8da31f826e6cb6c367336f95e418137286ba362925c877e"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:542da1178c1c6af8873e143910e2269add130a299c9106eef2594e15dae5e482"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:30a85aed0b864ac88309b7d94be09f6046c834ef60762a8833b660139cfbad13"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aae32c93e0f64469f74ccc730a7cb21c7610af3a775157e50bbd38f816536b38"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15b26ddf78d57f1d143bdf32e820fd8935d36abe8a25eb9ec0b5a71c82eb3895"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f5d10bae5d78e4551b7be7a9b29643a95aded9d0f602aa2ba584f0388e7a557"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:249c6470a2b60935bafd1d1d13cd613f8cd8388d53461c67397ee6a0f5dce741"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c5a74c359b2d47d26cdbbc7845e9662d6b08a1e915eb015d044729e92e7050b7"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:b5bcf60a228acae568e9911f410f9d9e0d43197d030ae5799e20dca8df588287"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:187d18082694a29005ba2944c882344b6748d5be69e3a89bf3cc9d878e548d5a"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:81bf654678e575403736b85ba3a7867e31c2c30a69bc57fe88e3ace52fb17b89"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-win32.whl", hash = "sha256:85a32721ddde63c9df9ebb0d2045b9691d9750cb139c161c80e500d210f5e26e"}, - {file = "charset_normalizer-3.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:468d2a840567b13a590e67dd276c570f8de00ed767ecc611994c301d0f8c014f"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e0fc42822278451bc13a2e8626cf2218ba570f27856b536e00cfa53099724828"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:09c77f964f351a7369cc343911e0df63e762e42bac24cd7d18525961c81754f4"}, - {file = 
"charset_normalizer-3.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:12ebea541c44fdc88ccb794a13fe861cc5e35d64ed689513a5c03d05b53b7c82"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:805dfea4ca10411a5296bcc75638017215a93ffb584c9e344731eef0dcfb026a"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96c2b49eb6a72c0e4991d62406e365d87067ca14c1a729a870d22354e6f68115"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aaf7b34c5bc56b38c931a54f7952f1ff0ae77a2e82496583b247f7c969eb1479"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:619d1c96099be5823db34fe89e2582b336b5b074a7f47f819d6b3a57ff7bdb86"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0ac5e7015a5920cfce654c06618ec40c33e12801711da6b4258af59a8eff00a"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:93aa7eef6ee71c629b51ef873991d6911b906d7312c6e8e99790c0f33c576f89"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7966951325782121e67c81299a031f4c115615e68046f79b85856b86ebffc4cd"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:02673e456dc5ab13659f85196c534dc596d4ef260e4d86e856c3b2773ce09843"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:c2af80fb58f0f24b3f3adcb9148e6203fa67dd3f61c4af146ecad033024dde43"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:153e7b6e724761741e0974fc4dcd406d35ba70b92bfe3fedcb497226c93b9da7"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-win32.whl", hash = "sha256:d47ecf253780c90ee181d4d871cd655a789da937454045b17b5798da9393901a"}, - {file = "charset_normalizer-3.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:d97d85fa63f315a8bdaba2af9a6a686e0eceab77b3089af45133252618e70884"}, - {file = "charset_normalizer-3.3.0-py3-none-any.whl", hash = "sha256:e46cd37076971c1040fc8c41273a8b3e2c624ce4f2be3f5dfcb7a430c1d3acc2"}, + {file = "charset-normalizer-3.3.1.tar.gz", hash = "sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-win32.whl", hash = "sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f"}, + {file = "charset_normalizer-3.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-win32.whl", hash = "sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8"}, + {file = "charset_normalizer-3.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-win32.whl", hash = "sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61"}, + {file = "charset_normalizer-3.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67"}, + {file = 
"charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-win32.whl", hash = "sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9"}, + {file = "charset_normalizer-3.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-win32.whl", hash = "sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb"}, + {file = "charset_normalizer-3.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-win32.whl", hash = "sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4"}, + {file = "charset_normalizer-3.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727"}, + {file = "charset_normalizer-3.3.1-py3-none-any.whl", hash = "sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708"}, ] [[package]] @@ -741,13 +741,13 @@ files = [ [[package]] name = "fsspec" -version = "2023.9.2" +version = "2023.10.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2023.9.2-py3-none-any.whl", hash = "sha256:603dbc52c75b84da501b9b2ec8c11e1f61c25984c4a0dda1f129ef391fbfc9b4"}, - {file = "fsspec-2023.9.2.tar.gz", hash = "sha256:80bfb8c70cc27b2178cc62a935ecf242fc6e8c3fb801f9c571fc01b1e715ba7d"}, + {file = "fsspec-2023.10.0-py3-none-any.whl", hash = "sha256:346a8f024efeb749d2a5fca7ba8854474b1ff9af7c3faaf636a4548781136529"}, + {file = 
"fsspec-2023.10.0.tar.gz", hash = "sha256:330c66757591df346ad3091a53bd907e15348c2ba17d63fd54f5c39c4457d2a5"}, ] [package.dependencies] @@ -1017,6 +1017,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.3.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, + {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, +] + [[package]] name = "kiwisolver" version = "1.4.5" @@ -2165,44 +2176,30 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = 
"sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:1758ce7d8e1a29d23de54a16ae867abd370f01b5a69e1a3ba75223eaa3ca1a1b"}, {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = 
"sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"}, {file = "ruamel.yaml.clib-0.2.8.tar.gz", hash = "sha256:beb2e0404003de9a4cab9753a8805a8fe9320ee6673136ed7f04255fe60bb512"}, @@ -2353,6 +2350,53 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "scikit-learn" +version = "1.3.2" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.8" +files = [ + {file = "scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05"}, + {file = "scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1"}, + {file = "scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a"}, + {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c"}, + {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161"}, + {file = "scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c"}, + {file = "scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66"}, + {file = "scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157"}, + {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb"}, + {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433"}, + {file = "scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b"}, + {file = "scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028"}, + {file = "scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5"}, + {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525"}, + {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c"}, + {file = "scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107"}, + {file = "scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93"}, + {file = "scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073"}, + {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d"}, + {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf"}, + {file = "scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0"}, + {file = "scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03"}, + {file = "scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e"}, + {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a"}, + {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9"}, + {file = "scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0"}, +] + +[package.dependencies] +joblib = ">=1.1.1" +numpy = ">=1.17.3,<2.0" +scipy = ">=1.5.0" +threadpoolctl = ">=2.0.0" + +[package.extras] +benchmark = ["matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "pandas (>=1.0.5)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.10.1)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] +examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)"] +tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.16.2)"] + [[package]] name = "scipy" version = "1.11.3" @@ -2728,6 +2772,17 @@ files = [ {file = "tensorboard_data_server-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594"}, ] +[[package]] +name = "threadpoolctl" +version = "3.2.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"}, + {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"}, +] + [[package]] name = "timm" version = "0.9.8" @@ -3141,4 +3196,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "26a8e38fe8e20d71df208d7f3d1bfbba68c82b685599fdc06aead21b89ee4a1d" +content-hash = "4d1f73c72a3be267829f0d244fb42fc4d4e073b68a70b860160bfc22bc66b184" diff --git a/pyproject.toml b/pyproject.toml index 6e0e461e..a0033930 100644 
--- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ torchinfo = ">=1.7.1" scipy = "^1.10.0" huggingface-hub = "^0.14.1" pandas = "^2.0.3" +scikit-learn = "^1.3.2" [tool.poetry.group.dev] optional = true diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index 1a9e9bd3..e43bfed1 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -67,7 +67,7 @@ def __new__( cls, in_features: int, out_features: int, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, baseline_type: str = "single", dist_estimation: int = 1, diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index b2dc00ba..99a5313c 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -1,15 +1,18 @@ from argparse import ArgumentParser from pathlib import Path from typing import Any, List, Optional, Union +from numpy.typing import ArrayLike +import numpy as np import torchvision.transforms as T -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader + +from torch_uncertainty.datamodules.abstract import AbstractDataModule from .dataset import DummyClassificationDataset, DummyRegressionDataset -class DummyClassificationDataModule(LightningDataModule): +class DummyClassificationDataModule(AbstractDataModule): num_channels = 1 image_size: int = 4 training_task = "classification" @@ -25,17 +28,16 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() - - root = Path(root) + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - self.root: Path = root self.evaluate_ood = evaluate_ood - self.batch_size = batch_size self.num_classes = num_classes - self.num_workers = num_workers - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers self.dataset = DummyClassificationDataset self.ood_dataset = DummyClassificationDataset @@ -78,29 +80,17 @@ def setup(self, stage: Optional[str] = None) -> None: transform=self.transform_test, ) - def train_dataloader(self) -> DataLoader: - return self._data_loader(self.train, shuffle=True) - - def val_dataloader(self) -> DataLoader: - return self._data_loader(self.val) - def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: dataloader = [self._data_loader(self.test)] if self.evaluate_ood: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) + def _get_train_data(self) -> ArrayLike: + return self.train.data + + def _get_train_targets(self) -> ArrayLike: + return np.array(self.train.targets) @classmethod def add_argparse_args( @@ -108,15 +98,12 @@ def add_argparse_args( parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=2) - p.add_argument("--num_workers", type=int, default=1) + p = super().add_argparse_args(parent_parser) p.add_argument("--evaluate_ood", action="store_true") return parent_parser -class DummyRegressionDataModule(LightningDataModule): +class 
DummyRegressionDataModule(AbstractDataModule): in_features = 4 training_task = "regression" @@ -131,17 +118,16 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - if isinstance(root, str): - root = Path(root) - self.root: Path = root self.evaluate_ood = evaluate_ood - self.batch_size = batch_size self.out_features = out_features - self.num_workers = num_workers - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers self.dataset = DummyRegressionDataset self.ood_dataset = DummyRegressionDataset @@ -177,39 +163,18 @@ def setup(self, stage: Optional[str] = None) -> None: transform=self.transform_test, ) - def train_dataloader(self) -> DataLoader: - return self._data_loader(self.train, shuffle=True) - - def val_dataloader(self) -> DataLoader: - return self._data_loader(self.val) - def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: dataloader = [self._data_loader(self.test)] if self.evaluate_ood: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) - @classmethod def add_argparse_args( cls, parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=2) - p.add_argument("--num_workers", type=int, default=1) + p = super().add_argparse_args(parent_parser) p.add_argument("--evaluate_ood", action="store_true") return parent_parser diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index ee830367..fa26669a 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -1,12 +1,11 @@ from pathlib import Path from typing import Any, Callable, Tuple +import numpy as np import torch import torch.utils.data as data from PIL import Image -import numpy as np - class DummyClassificationDataset(data.Dataset): def __init__( @@ -69,6 +68,9 @@ def __init__( num_images // (num_classes) + 1 )[:num_images] + self.samples = self.data # for compatibility with TinyImagenet + self.label_data = self.targets + def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index e048438c..07900527 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -22,6 +22,9 @@ def __init__( self.num_estimators = num_estimators + def feats_forward(self, x: Tensor) -> Tensor: + return self.forward(x) + def forward(self, x: Tensor) -> Tensor: out = self.linear( torch.ones( diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py new file mode 100644 index 00000000..ea6f3807 --- /dev/null +++ b/tests/datamodules/test_abstract_datamodule.py @@ -0,0 +1,56 @@ +from pathlib import Path +import pytest + +from torch_uncertainty.datamodules.abstract import ( + AbstractDataModule, + CrossValDataModule, +) + +from .._dummies.dataset import DummyClassificationDataset + + +class TestAbstractDataModule: + """Testing the AbstractDataModule class.""" + + def test_errors(self): + dm = AbstractDataModule("root", 128, 
4, True, True)
+        with pytest.raises(NotImplementedError):
+            dm.setup()
+        with pytest.raises(NotImplementedError):
+            dm._get_train_data()
+        with pytest.raises(NotImplementedError):
+            dm._get_train_targets()
+
+
+class TestCrossValDataModule:
+    """Testing the CrossValDataModule class."""
+
+    def test_cv_main(self):
+        dm = AbstractDataModule("root", 128, 4, True, True)
+        ds = DummyClassificationDataset(Path("root"))
+        dm.train = ds
+        dm.val = ds
+        dm.test = ds
+        cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 4, True, True)
+
+        cv_dm.setup()
+        cv_dm.setup("test")
+
+        # test abstract methods
+        cv_dm.get_train_set()
+        cv_dm.get_val_set()
+        cv_dm.get_test_set()
+
+        cv_dm.train_dataloader()
+        cv_dm.val_dataloader()
+        cv_dm.test_dataloader()
+
+    def test_errors(self):
+        dm = AbstractDataModule("root", 128, 4, True, True)
+        ds = DummyClassificationDataset(Path("root"))
+        dm.train = ds
+        dm.val = ds
+        dm.test = ds
+        cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 4, True, True)
+        cv_dm.setup()
+        with pytest.raises(NotImplementedError):
+            cv_dm._get_train_data()
+        with pytest.raises(NotImplementedError):
+            cv_dm._get_train_targets()
diff --git a/tests/datamodules/test_cifar100_datamodule.py b/tests/datamodules/test_cifar100_datamodule.py
index 480b453d..0fa48f0c 100644
--- a/tests/datamodules/test_cifar100_datamodule.py
+++ b/tests/datamodules/test_cifar100_datamodule.py
@@ -75,3 +75,35 @@ def test_cifar100(self):
         args.auto_augment = "rand-m9-n2-mstd0.5"
         dm = CIFAR100DataModule(**vars(args))
+
+    def test_cifar100_cv(self):
+        parser = ArgumentParser()
+        parser = CIFAR100DataModule.add_argparse_args(parser)
+
+        # Parse the default arguments
+        args = parser.parse_args("")
+
+        dm = CIFAR100DataModule(**vars(args))
+        dm.dataset = (
+            lambda root, train, download, transform: DummyClassificationDataset(
+                root,
+                train=train,
+                download=download,
+                transform=transform,
+                num_images=20,
+            )
+        )
+        dm.make_cross_val_splits(2, 1)
+
+        args.val_split = 0.1
+        dm = CIFAR100DataModule(**vars(args))
+        dm.dataset = (
+            lambda root, train, download, transform: DummyClassificationDataset(
+                root,
+                train=train,
+                download=download,
+                transform=transform,
+                num_images=20,
+            )
+        )
+        dm.make_cross_val_splits(2, 1)
diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/test_cifar10_datamodule.py
index c7e38778..53376dea 100644
--- a/tests/datamodules/test_cifar10_datamodule.py
+++ b/tests/datamodules/test_cifar10_datamodule.py
@@ -12,7 +12,7 @@
 class TestCIFAR10DataModule:
     """Testing the CIFAR10DataModule datamodule class."""

-    def test_CIFAR10_cutout(self):
+    def test_CIFAR10_main(self):
         parser = ArgumentParser()
         parser = CIFAR10DataModule.add_argparse_args(parser)

@@ -32,6 +32,14 @@ def test_CIFAR10_cutout(self):
         dm.setup()
         dm.setup("test")

+        with pytest.raises(ValueError):
+            dm.setup("xxx")
+
+        # test abstract methods
+        dm.get_train_set()
+        dm.get_val_set()
+        dm.get_test_set()
+
         dm.train_dataloader()
         dm.val_dataloader()
         dm.test_dataloader()
@@ -47,6 +55,11 @@ def test_CIFAR10_cutout(self):
         with pytest.raises(ValueError):
             dm.setup()

+        args.test_alt = "h"
+        dm = CIFAR10DataModule(**vars(args))
+        dm.dataset = DummyClassificationDataset
+        dm.setup("test")
+
         args.test_alt = None
         args.num_dataloaders = 2
         args.val_split = 0.1
@@ -65,3 +78,35 @@ def test_CIFAR10_cutout(self):
         args.cutout = None
         args.auto_augment = "rand-m9-n2-mstd0.5"
         dm = CIFAR10DataModule(**vars(args))
+
+    def test_cifar10_cv(self):
+        parser = ArgumentParser()
+        parser = CIFAR10DataModule.add_argparse_args(parser)
+
+        # Parse the default arguments
+        args = parser.parse_args("")
+
+        dm = CIFAR10DataModule(**vars(args))
+        dm.dataset = (
+            lambda root, train, download, transform: DummyClassificationDataset(
+                root,
+                train=train,
+                download=download,
+                transform=transform,
+                num_images=20,
+            )
+        )
+        dm.make_cross_val_splits(2, 1)
+
+        args.val_split = 0.1
+        dm = CIFAR10DataModule(**vars(args))
+        dm.dataset = (
+            lambda root, train, download, transform: DummyClassificationDataset(
+                root,
+                train=train,
+                download=download,
+                transform=transform,
+                num_images=20,
+            )
+        )
+        dm.make_cross_val_splits(2, 1)
diff --git a/tests/datamodules/test_tiny_imagenet_datamodule.py b/tests/datamodules/test_tiny_imagenet_datamodule.py
index 7a5e7144..aa36fdbf 100644
--- a/tests/datamodules/test_tiny_imagenet_datamodule.py
+++ b/tests/datamodules/test_tiny_imagenet_datamodule.py
@@ -11,7 +11,7 @@
 class TestTinyImageNetDataModule:
     """Testing the TinyImageNetDataModule datamodule class."""

-    def test_imagenet(self):
+    def test_tiny_imagenet(self):
         parser = ArgumentParser()
         parser = TinyImageNetDataModule.add_argparse_args(parser)

@@ -46,3 +46,23 @@ def test_imagenet(self):
         dm.prepare_data()
         dm.setup("test")
         dm.test_dataloader()
+
+    def test_tiny_imagenet_cv(self):
+        parser = ArgumentParser()
+        parser = TinyImageNetDataModule.add_argparse_args(parser)
+
+        # Parse the default arguments
+        args = parser.parse_args("")
+
+        dm = TinyImageNetDataModule(**vars(args))
+        dm.dataset = lambda root, split, transform: DummyClassificationDataset(
+            root, split=split, transform=transform, num_images=20
+        )
+        dm.make_cross_val_splits(2, 1)
+
+        args.val_split = 0.1
+        dm = TinyImageNetDataModule(**vars(args))
+        dm.dataset = lambda root, split, transform: DummyClassificationDataset(
+            root, split=split, transform=transform, num_images=20
+        )
+        dm.make_cross_val_splits(2, 1)
diff --git a/tests/models/test_resnets.py b/tests/models/test_resnets.py
index 92f8390a..49a8c462 100644
--- a/tests/models/test_resnets.py
+++ b/tests/models/test_resnets.py
@@ -38,6 +38,7 @@ def test_main(self):
         model = resnet50(1, 10, 1)
         with torch.no_grad():
             model(torch.randn(2, 1, 32, 32))
+            model.feats_forward(torch.randn(2, 1, 32, 32))

     def test_mc_dropout(self):
         resnet34(1, 10, 1, num_estimators=5)
diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py
index bc8574ec..62ed2968 100644
--- a/tests/routines/test_classification.py
+++ b/tests/routines/test_classification.py
@@ -16,6 +16,7 @@
 from .._dummies import (
     DummyClassificationBaseline,
     DummyClassificationDataModule,
+    DummyClassificationDataset,
 )

@@ -40,7 +41,7 @@ def test_cli_main_dummy_binary(self):
             baseline_type="single",
             **vars(args),
         )
-        cli_main(model, dm, root, "dummy", args)
+        cli_main(model, dm, root, "logs/dummy", args)

         with ArgvContext("file.py", "--logits"):
             args = init_args(
@@ -58,7 +59,7 @@
             baseline_type="single",
             **vars(args),
         )
-        cli_main(model, dm, root, "dummy", args)
+        cli_main(model, dm, root, "logs/dummy", args)

     def test_cli_main_dummy_ood(self):
         root = Path(__file__).parent.absolute().parents[0]
@@ -83,7 +84,7 @@
             baseline_type="single",
             **vars(args),
         )
-        cli_main(model, dm, root, "dummy", args)
+        cli_main(model, dm, root, "logs/dummy", args)

         with ArgvContext(
             "file.py",
@@ -105,10 +106,16 @@
             baseline_type="single",
             **vars(args),
         )
-        cli_main(model, dm, root, "dummy", args)
+        cli_main(model, dm, root, "logs/dummy", args)

         with ArgvContext(
-            "file.py", "--evaluate_ood", "--entropy", "--cutmix", "0.5"
+            "file.py",
+            "--evaluate_ood",
+            "--entropy",
+            "--cutmix_alpha",
+ "0.5", + "--mixtype", + "timm", ): args = init_args( DummyClassificationBaseline, DummyClassificationDataModule @@ -126,7 +133,107 @@ def test_cli_main_dummy_ood(self): **vars(args), ) with pytest.raises(NotImplementedError): - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) + + def test_cli_main_dummy_mixup_ts_cv(self): + root = Path(__file__).parent.absolute().parents[0] + with ArgvContext( + "file.py", + "--mixtype", + "kernel_warping", + "--mixup_alpha", + "1.", + "--dist_sim", + "inp", + "--val_temp_scaling", + "--use_cv", + ): + args = init_args( + DummyClassificationBaseline, DummyClassificationDataModule + ) + + args.root = str(root / "data") + dm = DummyClassificationDataModule(num_classes=10, **vars(args)) + dm.dataset = ( + lambda root, + num_channels, + num_classes, + image_size, + transform: DummyClassificationDataset( + root, + num_channels=num_channels, + num_classes=num_classes, + image_size=image_size, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + DummyClassificationBaseline( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + calibration_set=dm.get_val_set, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "logs/dummy", args) + + with ArgvContext( + "file.py", + "--mixtype", + "kernel_warping", + "--mixup_alpha", + "1.", + "--dist_sim", + "emb", + "--val_temp_scaling", + "--use_cv", + ): + args = init_args( + DummyClassificationBaseline, DummyClassificationDataModule + ) + + args.root = str(root / "data") + dm = DummyClassificationDataModule(num_classes=10, **vars(args)) + dm.dataset = ( + lambda root, + num_channels, + num_classes, + image_size, + transform: DummyClassificationDataset( + root, + num_channels=num_channels, + num_classes=num_classes, + image_size=image_size, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + DummyClassificationBaseline( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + calibration_set=dm.get_val_set, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "logs/dummy", args) def test_classification_failures(self): with pytest.raises(ValueError): @@ -166,7 +273,7 @@ def test_cli_main_dummy_binary(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext("file.py", "--mutual_information"): args = init_args( @@ -186,7 +293,7 @@ def test_cli_main_dummy_binary(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_cli_main_dummy_ood(self): root = Path(__file__).parent.absolute().parents[0] @@ -208,7 +315,7 @@ def test_cli_main_dummy_ood(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext("file.py", "--evaluate_ood", "--entropy"): args = init_args( @@ -228,7 +335,7 @@ def test_cli_main_dummy_ood(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext("file.py", "--evaluate_ood", 
"--variation_ratio"): args = init_args( @@ -248,7 +355,7 @@ def test_cli_main_dummy_ood(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_classification_failures(self): with pytest.raises(ValueError): diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 956cece2..64c6e354 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -6,7 +6,7 @@ from torch import nn from torch_uncertainty import cli_main, init_args -from torch_uncertainty.losses import NIGLoss, BetaNLL +from torch_uncertainty.losses import BetaNLL, NIGLoss from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 from .._dummies import DummyRegressionBaseline, DummyRegressionDataModule @@ -34,7 +34,7 @@ def test_cli_main_dummy_dist(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_cli_main_dummy_dist_der(self): root = Path(__file__).parent.absolute().parents[0] @@ -60,7 +60,7 @@ def test_cli_main_dummy_dist_der(self): **vars(args), ) - cli_main(model, dm, root, "dummy_der", args) + cli_main(model, dm, root, "logs/dummy_der", args) def test_cli_main_dummy_dist_betanll(self): root = Path(__file__).parent.absolute().parents[0] @@ -86,7 +86,7 @@ def test_cli_main_dummy_dist_betanll(self): **vars(args), ) - cli_main(model, dm, root, "dummy_betanll", args) + cli_main(model, dm, root, "logs/dummy_betanll", args) def test_cli_main_dummy(self): root = Path(__file__).parent.absolute().parents[0] @@ -106,7 +106,7 @@ def test_cli_main_dummy(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_regression_failures(self): with pytest.raises(ValueError): @@ -158,4 +158,4 @@ def test_cli_main_dummy(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) diff --git a/tests/test_cli.py b/tests/test_cli.py index 73d6d74f..9fd24dea 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ import sys +import os from pathlib import Path import pytest @@ -15,6 +16,9 @@ optim_cifar10_wideresnet, optim_regression, ) +from torch_uncertainty.utils.misc import csv_writer + +from ._dummies.dataset import DummyClassificationDataset class TestCLI: @@ -41,7 +45,21 @@ def test_cli_main_resnet(self): **vars(args), ) - cli_main(model, dm, root, "std", args) + results = cli_main(model, dm, root, "std", args) + results_path = root / "tests" / "logs" + if not os.path.exists(results_path): + os.makedirs(results_path) + for dict_result in results: + csv_writer( + results_path / "results.csv", + dict_result, + ) + # Test if file already exists + for dict_result in results: + csv_writer( + results_path / "results.csv", + dict_result, + ) def test_cli_main_other_arguments(self): root = Path(__file__).parent.absolute().parents[0] @@ -136,13 +154,16 @@ def test_cli_main_mlp(self): cli_main(model, dm, root, "std", args) + args.test = True + cli_main(model, dm, root, "std", args) + def test_cli_other_training_task(self): root = Path(__file__).parent.absolute().parents[0] with ArgvContext("file.py"): args = init_args(MLP, UCIDataModule) # datamodule - args.root = root / "/data" + args.root = root / "data" dm = UCIDataModule( dataset_name="kin8nm", input_shape=(1, 5), **vars(args) ) @@ -163,6 +184,205 @@ def test_cli_other_training_task(self): with pytest.raises(ValueError): cli_main(model, dm, root, "std", 
args) + def test_cli_cv_ts(self): + root = Path(__file__).parent.absolute().parents[0] + with ArgvContext("file.py", "--use_cv", "--channels_last"): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + + with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup"): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + + with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup_io"): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + + with ArgvContext("file.py", "--use_cv", "--mixtype", "regmixup"): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + + with ArgvContext( + "file.py", "--use_cv", "--mixtype", 
"kernel_warping" + ): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + def test_init_args_void(self): with ArgvContext("file.py"): init_args() diff --git a/tests/test_losses.py b/tests/test_losses.py index 96412e30..89a110ad 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -5,7 +5,7 @@ from torch import nn from torch_uncertainty.layers.bayesian import BayesLinear -from torch_uncertainty.losses import DECLoss, ELBOLoss, NIGLoss, BetaNLL +from torch_uncertainty.losses import BetaNLL, DECLoss, ELBOLoss, NIGLoss class TestELBOLoss: diff --git a/tests/test_optimization_procedures.py b/tests/test_optimization_procedures.py index 8fb63011..e250b547 100644 --- a/tests/test_optimization_procedures.py +++ b/tests/test_optimization_procedures.py @@ -1,7 +1,7 @@ # ruff: noqa: F401 import pytest -from torch_uncertainty.models.resnet import resnet18, resnet50 +from torch_uncertainty.models.resnet import resnet18, resnet34, resnet50 from torch_uncertainty.models.vgg import vgg16 from torch_uncertainty.models.wideresnet import wideresnet28x10 from torch_uncertainty.optimization_procedures import ( @@ -11,46 +11,57 @@ class TestOptProcedures: - def test_optim_cifar10_resnet18(self): + def test_optim_cifar10(self): procedure = get_procedure("resnet18", "cifar10", "standard") model = resnet18(in_channels=3, num_classes=10) procedure(model) - def test_optim_cifar10_resnet50(self): + procedure = get_procedure("resnet34", "cifar10", "masked") + model = resnet34(in_channels=3, num_classes=100) + procedure(model) + procedure = get_procedure("resnet50", "cifar10", "packed") model = resnet50(in_channels=3, num_classes=10) procedure(model) - def test_optim_cifar10_wideresnet(self): procedure = get_procedure("wideresnet28x10", "cifar10", "batched") model = wideresnet28x10(in_channels=3, num_classes=10) procedure(model) - def test_optim_cifar10_vgg16(self): procedure = get_procedure("vgg16", "cifar10", "standard") model = vgg16(in_channels=3, num_classes=10) procedure(model) - def test_optim_cifar100_resnet18(self): + def test_optim_cifar100(self): procedure = get_procedure("resnet18", "cifar100", "masked") model = resnet18(in_channels=3, num_classes=100) procedure(model) - def test_optim_cifar100_resnet50(self): + procedure = get_procedure("resnet34", "cifar100", "masked") + model = resnet34(in_channels=3, num_classes=100) + procedure(model) + procedure = get_procedure("resnet50", "cifar100") model = resnet50(in_channels=3, num_classes=100) procedure(model) - def test_optim_cifar100_wideresnet(self): procedure = get_procedure("wideresnet28x10", "cifar100") model = wideresnet28x10(in_channels=3, num_classes=100) procedure(model) - def test_optim_cifar100_vgg16(self): procedure = get_procedure("vgg16", "cifar100", "standard") model = vgg16(in_channels=3, num_classes=100) 
procedure(model) + def test_optim_tinyimagenet(self): + procedure = get_procedure("resnet34", "tiny-imagenet", "standard") + model = resnet34(in_channels=3, num_classes=1000) + procedure(model) + + procedure = get_procedure("resnet50", "tiny-imagenet", "standard") + model = resnet50(in_channels=3, num_classes=1000) + procedure(model) + def test_optim_imagenet_resnet50(self): procedure = get_procedure("resnet50", "imagenet", "standard", "A3") model = resnet50(in_channels=3, num_classes=1000) diff --git a/tests/transforms/test_mixup.py b/tests/transforms/test_mixup.py new file mode 100644 index 00000000..d71dbc99 --- /dev/null +++ b/tests/transforms/test_mixup.py @@ -0,0 +1,79 @@ +from typing import Tuple + +import pytest +import torch + +from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup +from torch_uncertainty.transforms.mixup import AbstractMixup + + +@pytest.fixture +def batch_input() -> Tuple[torch.Tensor, torch.Tensor]: + imgs = torch.rand(2, 3, 28, 28) + return imgs, torch.tensor([0, 1]) + + +class TestAbstractMixup: + """Testing AbstractMixup augmentation""" + + def test_abstract_mixup(self, batch_input): + with pytest.raises(NotImplementedError): + AbstractMixup()(*batch_input) + + +class TestMixup: + """Testing Mixup augmentation""" + + def test_batch_mixup(self, batch_input): + mixup = Mixup(alpha=1.0, mode="batch", num_classes=2) + _ = mixup(*batch_input) + + def test_elem_mixup(self, batch_input): + mixup = Mixup(alpha=1.0, mode="elem", num_classes=2) + _ = mixup(*batch_input) + + +class TestMixupIO: + """Testing MixupIO augmentation""" + + def test_batch_mixupio(self, batch_input): + mixup = MixupIO(alpha=1.0, mode="batch", num_classes=2) + _ = mixup(*batch_input) + + def test_elem_mixupio(self, batch_input): + mixup = MixupIO(alpha=1.0, mode="elem", num_classes=2) + _ = mixup(*batch_input) + + +class TestRegMixup: + """Testing RegMixup augmentation""" + + def test_batch_regmixup(self, batch_input): + mixup = RegMixup(alpha=1.0, mode="batch", num_classes=2) + _ = mixup(*batch_input) + + def test_elem_regmixup(self, batch_input): + mixup = RegMixup(alpha=1.0, mode="elem", num_classes=2) + _ = mixup(*batch_input) + + +class TestWarpingMixup: + """Testing WarpingMixup augmentation""" + + def test_batch_kernel_warpingmixup(self, batch_input): + mixup = WarpingMixup( + alpha=1.0, mode="batch", num_classes=2, apply_kernel=True + ) + _ = mixup(*batch_input, batch_input[0]) + + def test_elem_kernel_warpingmixup(self, batch_input): + mixup = WarpingMixup( + alpha=1.0, mode="elem", num_classes=2, apply_kernel=True + ) + _ = mixup(*batch_input, batch_input[0]) + + def test_elem_warpingmixup(self, batch_input): + mixup = WarpingMixup( + alpha=1.0, mode="elem", num_classes=2, apply_kernel=False + ) + _ = mixup(*batch_input, batch_input[0]) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index c6e30d34..be18de67 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -1,10 +1,10 @@ from typing import Tuple +import numpy import pytest import torch from PIL import Image -import numpy from torch_uncertainty.transforms import ( AutoContrast, Brightness, diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 24dcde90..e3612ae6 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -1,8 +1,10 @@ # ruff: noqa: F401 from argparse import ArgumentParser, Namespace +from collections import defaultdict from pathlib import Path -from typing import 
Dict, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union
+import numpy as np
 import pytorch_lightning as pl
 import torch
 from pytorch_lightning.callbacks import LearningRateMonitor
@@ -11,13 +13,12 @@
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from torchinfo import summary

-import numpy as np
-
+from .datamodules.abstract import AbstractDataModule
 from .utils import get_version

 def init_args(
-    network: Optional[Type[pl.LightningModule]] = None,
+    network: Any = None,
     datamodule: Optional[Type[pl.LightningDataModule]] = None,
 ) -> Namespace:
     parser = ArgumentParser("torch-uncertainty")
@@ -33,6 +34,9 @@ def init_args(
         default=None,
         help="Run in test mode. Set to the checkpoint version number to test.",
     )
+    parser.add_argument(
+        "--ckpt", type=int, default=None, help="Number of the checkpoint to load"
+    )
     parser.add_argument(
         "--summary",
         dest="summary",
@@ -50,7 +54,30 @@ def init_args(
         action="store_true",
         help="Allow resuming the training (save optimizer's states)",
     )
-
+    parser.add_argument(
+        "--exp_dir",
+        type=str,
+        default="logs/",
+        help="Directory to store experiment files",
+    )
+    parser.add_argument(
+        "--exp_name",
+        type=str,
+        default="",
+        help="Name of the experiment folder",
+    )
+    parser.add_argument(
+        "--opt_temp_scaling",
+        action="store_true",
+        default=False,
+        help="Compute optimal temperature on the test set",
+    )
+    parser.add_argument(
+        "--val_temp_scaling",
+        action="store_true",
+        default=False,
+        help="Compute temperature on the validation set",
+    )
     parser = pl.Trainer.add_argparse_args(parser)
     if network is not None:
         parser = network.add_model_specific_args(parser)
@@ -62,16 +89,19 @@
 def cli_main(
-    network: pl.LightningModule,
-    datamodule: pl.LightningDataModule,
+    network: pl.LightningModule | List[pl.LightningModule],
+    datamodule: AbstractDataModule | List[AbstractDataModule],
     root: Union[Path, str],
     net_name: str,
     args: Namespace,
-) -> Dict:
+) -> List[Dict]:
     if isinstance(root, str):
         root = Path(root)

-    training_task = datamodule.training_task
+    if isinstance(datamodule, list):
+        training_task = datamodule[0].dm.training_task
+    else:
+        training_task = datamodule.training_task
     if training_task == "classification":
         monitor = "hp/val_acc"
         mode = "max"
@@ -92,57 +122,135 @@
     pl.seed_everything(args.seed, workers=True)

     if args.channels_last:
-        network = network.to(memory_format=torch.channels_last)
-
-    # logger
-    tb_logger = TensorBoardLogger(
-        str(root / "logs"),
-        name=net_name,
-        default_hp_metric=False,
-        log_graph=args.log_graph,
-        version=args.test,
-    )
+        if isinstance(network, list):
+            for i in range(len(network)):
+                network[i] = network[i].to(memory_format=torch.channels_last)
+        else:
+            network = network.to(memory_format=torch.channels_last)

-    # callbacks
-    save_checkpoints = ModelCheckpoint(
-        monitor=monitor,
-        mode=mode,
-        save_last=True,
-        save_weights_only=not args.enable_resume,
-    )
+    if args.use_cv:
+        test_values = []
+        for i in range(len(datamodule)):
+            print(
+                f"Starting fold {i + 1}/{args.train_over} of a {args.n_splits}-fold CV."
+            )

-    # Select the best model, monitor the lr and stop if NaN
-    callbacks = [
-        save_checkpoints,
-        LearningRateMonitor(logging_interval="step"),
-        EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True),
-    ]
-    # trainer
-    trainer = pl.Trainer.from_argparse_args(
-        args,
-        callbacks=callbacks,
-        logger=tb_logger,
-        deterministic=(args.seed is not None),
-    )
+            # logger
+            tb_logger = TensorBoardLogger(
+                str(root),
+                name=net_name,
+                default_hp_metric=False,
+                log_graph=args.log_graph,
+                version=f"fold_{i}",
+            )

-    if args.summary:
-        summary(network, input_size=list(datamodule.input_shape).insert(0, 1))
-        test_values = {}
-    elif args.test is not None:  # coverage: ignore
-        if args.test >= 0:
-            ckpt_file, _ = get_version(
-                root=(root / "logs" / net_name), version=args.test
+            # callbacks
+            save_checkpoints = ModelCheckpoint(
+                dirpath=tb_logger.log_dir,
+                monitor=monitor,
+                mode=mode,
+                save_last=True,
+                save_weights_only=not args.enable_resume,
             )
-            test_values = trainer.test(
-                network, datamodule=datamodule, ckpt_path=str(ckpt_file)
+
+            # Select the best model, monitor the lr and stop if NaN
+            callbacks = [
+                save_checkpoints,
+                LearningRateMonitor(logging_interval="step"),
+                EarlyStopping(
+                    monitor=monitor, patience=np.inf, check_finite=True
+                ),
+            ]
+
+            trainer = pl.Trainer.from_argparse_args(
+                args,
+                callbacks=callbacks,
+                logger=tb_logger,
+                deterministic=(args.seed is not None),
+                inference_mode=not (
+                    args.opt_temp_scaling or args.val_temp_scaling
+                ),
             )
-        else:
-            test_values = trainer.test(network, datamodule=datamodule)
+            if args.summary:
+                summary(
+                    network[i],
+                    # note: list.insert returns None, so the size list is
+                    # built directly instead
+                    input_size=[1, *datamodule[i].dm.input_shape],
+                )
+                test_values.append({})
+            else:
+                trainer.fit(network[i], datamodule[i])
+                test_values.append(
+                    trainer.test(datamodule=datamodule[i], ckpt_path="last")[0]
+                )
+
+        all_test_values = defaultdict(list)
+        for test_value in test_values:
+            for key in test_value:
+                all_test_values[key].append(test_value[key])
+
+        avg_test_values = {}
+        for key in all_test_values:
+            avg_test_values[key] = np.mean(all_test_values[key])
+
+        return [avg_test_values]
     else:
-        # training and testing
-        trainer.fit(network, datamodule)
-        if args.fast_dev_run is False:
-            test_values = trainer.test(datamodule=datamodule, ckpt_path="best")
+        # logger
+        tb_logger = TensorBoardLogger(
+            str(root),
+            name=net_name,
+            default_hp_metric=False,
+            log_graph=args.log_graph,
+            version=args.test,
+        )
+
+        # callbacks
+        save_checkpoints = ModelCheckpoint(
+            monitor=monitor,
+            mode=mode,
+            save_last=True,
+            save_weights_only=not args.enable_resume,
+        )
+
+        # Select the best model, monitor the lr and stop if NaN
+        callbacks = [
+            save_checkpoints,
+            LearningRateMonitor(logging_interval="step"),
+            EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True),
+        ]
+
+        # trainer
+        trainer = pl.Trainer.from_argparse_args(
+            args,
+            callbacks=callbacks,
+            logger=tb_logger,
+            deterministic=(args.seed is not None),
+            inference_mode=not (args.opt_temp_scaling or args.val_temp_scaling),
+        )
+        if args.summary:
+            summary(
+                network,
+                input_size=[1, *datamodule.input_shape],
+            )
+            test_values = [{}]
+        elif args.test is not None:
+            if args.test >= 0:
+                ckpt_file, _ = get_version(
+                    root=(root / net_name),
+                    version=args.test,
+                    checkpoint=args.ckpt,
+                )
+                test_values = trainer.test(
+                    network, datamodule=datamodule, ckpt_path=str(ckpt_file)
+                )
+            else:
+                test_values = trainer.test(network, datamodule=datamodule)
         else:
-        test_values = {}
-    return test_values
+            # training and testing
+            
trainer.fit(network, datamodule) + if args.fast_dev_run is False: + test_values = trainer.test( + datamodule=datamodule, ckpt_path="best" + ) + else: + test_values = [{}] + return test_values diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 63759dd9..6bda52d4 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -1,6 +1,6 @@ from argparse import ArgumentParser from pathlib import Path -from typing import Any, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Type, Union import torch from pytorch_lightning import LightningModule @@ -26,7 +26,7 @@ def __new__( cls, num_outputs: int, in_features: int, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, version: Literal["vanilla", "packed"], hidden_dims: List[int], diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py new file mode 100644 index 00000000..ad414ef5 --- /dev/null +++ b/torch_uncertainty/datamodules/abstract.py @@ -0,0 +1,210 @@ +from argparse import ArgumentParser +from pathlib import Path +from typing import Any, List, Optional, Union + +from numpy.typing import ArrayLike +from pytorch_lightning import LightningDataModule +from sklearn.model_selection import StratifiedKFold +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.sampler import SubsetRandomSampler + + +class AbstractDataModule(LightningDataModule): + training_task: str + + def __init__( + self, + root: Union[str, Path], + batch_size: int, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + **kwargs, + ) -> None: + super().__init__() + + if isinstance(root, str): + root = Path(root) + self.root: Path = root + self.batch_size = batch_size + self.num_workers = num_workers + + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + + def setup(self, stage: Optional[str] = None) -> None: + raise NotImplementedError() + + def get_train_set(self) -> Dataset: + return self.train + + def get_test_set(self) -> Dataset: + return self.test + + def get_val_set(self) -> Dataset: + return self.val + + def train_dataloader(self) -> DataLoader: + r"""Get the training dataloader. + + Return: + DataLoader: training dataloader. + """ + return self._data_loader(self.train, shuffle=True) + + def val_dataloader(self) -> DataLoader: + r"""Get the validation dataloader. + + Return: + DataLoader: validation dataloader. + """ + return self._data_loader(self.val) + + def test_dataloader(self) -> List[DataLoader]: + r"""Get test dataloaders. + + Return: + List[DataLoader]: test set for in distribution data + and out-of-distribution data. + """ + dataloader = [self._data_loader(self.test)] + return dataloader + + def _data_loader( + self, dataset: Dataset, shuffle: bool = False + ) -> DataLoader: + """Create a dataloader for a given dataset. + + Args: + dataset (Dataset): Dataset to create a dataloader for. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults + to False. + + Return: + DataLoader: Dataloader for the given dataset. + """ + return DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=shuffle, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + ) + + # These two functions have to be defined in each datamodule + # by setting the correct path to the matrix of data for each dataset. 
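+    # A minimal sketch of the expected overrides (this is what the dummy
+    # datamodules in the tests do):
+    #     def _get_train_data(self): return self.train.data
+    #     def _get_train_targets(self): return np.array(self.train.targets)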
+ # It is generally "Dataset.samples" or "Dataset.data" + # They are used for constructing cross validation splits + def _get_train_data(self) -> ArrayLike: + raise NotImplementedError() + + def _get_train_targets(self) -> ArrayLike: + raise NotImplementedError() + + def make_cross_val_splits( + self, n_splits: int = 10, train_over: int = 4 + ) -> List: + self.setup("fit") + skf = StratifiedKFold(n_splits) + cv_dm = [] + + for fold, (train_idx, val_idx) in enumerate( + skf.split(self._get_train_data(), self._get_train_targets()) + ): + if fold >= train_over: + break + + fold_dm = CrossValDataModule( + root=self.root, + train_idx=train_idx, + val_idx=val_idx, + datamodule=self, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + ) + + cv_dm.append(fold_dm) + + return cv_dm + + @classmethod + def add_argparse_args( + cls, + parent_parser: ArgumentParser, + **kwargs: Any, + ) -> ArgumentParser: + p = parent_parser.add_argument_group("datamodule") + p.add_argument("--root", type=str, default="./data/") + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--val_split", type=float, default=0.0) + p.add_argument("--num_workers", type=int, default=4) + p.add_argument("--use_cv", action="store_true") + p.add_argument("--n_splits", type=int, default=10) + p.add_argument("--train_over", type=int, default=4) + return parent_parser + + +class CrossValDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + train_idx: ArrayLike, + val_idx: ArrayLike, + datamodule: AbstractDataModule, + batch_size: int, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + **kwargs, + ) -> None: + super().__init__( + root, + batch_size, + num_workers, + pin_memory, + persistent_workers, + **kwargs, + ) + + self.train_idx = train_idx + self.val_idx = val_idx + self.dm = datamodule + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + self.train = self.dm.train + self.val = self.dm.val + elif stage == "test": + self.test = self.val + + def _data_loader(self, dataset: Dataset, idx: ArrayLike) -> DataLoader: + return DataLoader( + dataset=dataset, + sampler=SubsetRandomSampler(idx), + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + ) + + def get_train_set(self) -> Dataset: + return self.dm.train + + def get_test_set(self) -> Dataset: + return self.dm.val + + def get_val_set(self) -> Dataset: + return self.dm.val + + def train_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.train_idx) + + def val_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.val_idx) + + def test_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.val_idx) diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index a2f3588a..a4b5990e 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -2,19 +2,21 @@ from pathlib import Path from typing import Any, List, Literal, Optional, Union +import numpy as np import torchvision.transforms as T -from pytorch_lightning import LightningDataModule +from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, 
Dataset, random_split +from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR10, SVHN from ..datasets import AggregatedDataset from ..datasets.classification import CIFAR10C, CIFAR10H from ..transforms import Cutout +from .abstract import AbstractDataModule -class CIFAR10DataModule(LightningDataModule): +class CIFAR10DataModule(AbstractDataModule): """DataModule for CIFAR10. Args: @@ -58,19 +60,17 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - if isinstance(root, str): - root = Path(root) - self.root: Path = root - self.evaluate_ood = evaluate_ood - self.batch_size = batch_size self.val_split = val_split - self.num_workers = num_workers self.num_dataloaders = num_dataloaders - - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers + self.evaluate_ood = evaluate_ood if test_alt == "c": self.dataset = CIFAR10C @@ -202,48 +202,29 @@ def train_dataloader(self) -> DataLoader: else: return self._data_loader(self.train, shuffle=True) - def val_dataloader(self) -> DataLoader: - r"""Gets the validation dataloader for CIFAR10. - - Returns: - DataLoader: CIFAR10 validation dataloader. - """ - return self._data_loader(self.val) - def test_dataloader(self) -> List[DataLoader]: - r"""Get the test dataloaders for CIFAR10. + r"""Get test dataloaders. Return: - List[DataLoader]: Dataloaders of the CIFAR10 test set (in - distribution data) and SVHN test split (out-of-distribution - data). + List[DataLoader]: test set for in distribution data + and out-of-distribution data. """ dataloader = [self._data_loader(self.test)] if self.evaluate_ood: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - """Create a dataloader for a given dataset. - - Args: - dataset (Dataset): Dataset to create a dataloader for. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults - to False. + def _get_train_data(self) -> ArrayLike: + if self.val_split: + return self.train.dataset.data[self.train.indices] + else: + return self.train.data - Return: - DataLoader: Dataloader for the given dataset. 
- """ - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) + def _get_train_targets(self) -> ArrayLike: + if self.val_split: + return np.array(self.train.dataset.targets)[self.train.indices] + else: + return np.array(self.train.targets) @classmethod def add_argparse_args( @@ -251,16 +232,14 @@ def add_argparse_args( parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=128) - p.add_argument("--val_split", type=float, default=0.0) - p.add_argument("--num_workers", type=int, default=4) - p.add_argument("--evaluate_ood", action="store_true") + p = super().add_argparse_args(parent_parser) + + # Arguments for CIFAR10 p.add_argument("--cutout", type=int, default=0) p.add_argument("--auto_augment", type=str) p.add_argument("--test_alt", choices=["c", "h"], default=None) p.add_argument( "--severity", dest="corruption_severity", type=int, default=None ) + p.add_argument("--evaluate_ood", action="store_true") return parent_parser diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/cifar100.py index 78cbfa9a..4a596c41 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/cifar100.py @@ -2,20 +2,22 @@ from pathlib import Path from typing import Any, List, Literal, Optional, Union +import numpy as np import torch import torchvision.transforms as T -from pytorch_lightning import LightningDataModule +from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR100, SVHN from ..datasets import AggregatedDataset from ..datasets.classification import CIFAR100C from ..transforms import Cutout +from .abstract import AbstractDataModule -class CIFAR100DataModule(LightningDataModule): +class CIFAR100DataModule(AbstractDataModule): """DataModule for CIFAR100. Args: @@ -60,21 +62,18 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() - - if isinstance(root, str): - root = Path(root) + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - self.root: Path = root self.evaluate_ood = evaluate_ood - self.batch_size = batch_size self.val_split = val_split - self.num_workers = num_workers self.num_dataloaders = num_dataloaders - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers - if test_alt == "c": self.dataset = CIFAR100C else: @@ -203,48 +202,29 @@ def train_dataloader(self) -> DataLoader: else: return self._data_loader(self.train, shuffle=True) - def val_dataloader(self) -> DataLoader: - """Get the validation dataloader for CIFAR100. - - Return: - DataLoader: CIFAR100 validation dataloader. - """ - return self._data_loader(self.val) - def test_dataloader(self) -> List[DataLoader]: - """Get the test dataloaders for CIFAR100. + r"""Get test dataloaders. Return: - List[DataLoader]: Dataloaders of the CIFAR100 test set (in - distribution data) and SVHN test split (out-of-distribution - data). 
+ List[DataLoader]: test set for in distribution data + and out-of-distribution data. """ dataloader = [self._data_loader(self.test)] if self.evaluate_ood: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - """Create a dataloader for a given dataset. - - Args: - dataset (Dataset): Dataset to create a dataloader for. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults - to False. + def _get_train_data(self) -> ArrayLike: + if self.val_split: + return self.train.dataset.data[self.train.indices] + else: + return self.train.data - Return: - DataLoader: Dataloader for the given dataset. - """ - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) + def _get_train_targets(self) -> ArrayLike: + if self.val_split: + return np.array(self.train.dataset.targets)[self.train.indices] + else: + return np.array(self.train.targets) @classmethod def add_argparse_args( @@ -252,12 +232,9 @@ def add_argparse_args( parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=128) - p.add_argument("--val_split", type=float, default=0.0) - p.add_argument("--num_workers", type=int, default=4) - p.add_argument("--evaluate_ood", action="store_true") + p = super().add_argparse_args(parent_parser) + + # Arguments for CIFAR100 p.add_argument("--cutout", type=int, default=0) p.add_argument("--randaugment", dest="randaugment", action="store_true") p.add_argument("--auto_augment", type=str) @@ -265,4 +242,5 @@ def add_argparse_args( p.add_argument( "--severity", dest="corruption_severity", type=int, default=1 ) + p.add_argument("--evaluate_ood", action="store_true") return parent_parser diff --git a/torch_uncertainty/datamodules/tiny_imagenet.py b/torch_uncertainty/datamodules/tiny_imagenet.py index f62e6f6e..e10b8d5d 100644 --- a/torch_uncertainty/datamodules/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/tiny_imagenet.py @@ -3,16 +3,17 @@ from typing import Any, List, Optional, Union import torchvision.transforms as T -from pytorch_lightning import LightningDataModule +from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import ConcatDataset, DataLoader, Dataset +from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN from ..datasets.classification import ImageNetO, TinyImageNet +from .abstract import AbstractDataModule -class TinyImageNetDataModule(LightningDataModule): +class TinyImageNetDataModule(AbstractDataModule): num_classes = 200 num_channels = 3 training_task = "classification" @@ -29,17 +30,16 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) # TODO: COMPUTE STATS - if isinstance(root, str): - root = Path(root) - self.root: Path = root self.evaluate_ood = evaluate_ood - self.batch_size = batch_size - self.num_workers = num_workers - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers self.ood_ds = ood_ds self.dataset = TinyImageNet @@ -72,8 +72,7 
@@ def __init__( self.transform_test = T.Compose( [ - T.Resize(72), - T.CenterCrop(64), + T.Resize(64), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ] @@ -169,55 +168,23 @@ def setup(self, stage: Optional[str] = None) -> None: transform=self.transform_test, ) - def train_dataloader(self) -> DataLoader: - r"""Get the training dataloader for TinyImageNet. - - Return: - DataLoader: TinyImageNet training dataloader. - """ - return self._data_loader(self.train, shuffle=True) - - def val_dataloader(self) -> DataLoader: - r"""Get the validation dataloader for TinyImageNet. - - Return: - DataLoader: TinyImageNet validation dataloader. - """ - return self._data_loader(self.val) - def test_dataloader(self) -> List[DataLoader]: - r"""Get test dataloaders for TinyImageNet. + r"""Get test dataloaders. Return: - List[DataLoader]: TinyImageNet test set (in distribution data) and - SVHN test split (out-of-distribution data). + List[DataLoader]: test set for in distribution data + and out-of-distribution data. """ dataloader = [self._data_loader(self.test)] if self.evaluate_ood: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - """Create a dataloader for a given dataset. - - Args: - dataset (Dataset): Dataset to create a dataloader for. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults - to False. + def _get_train_data(self) -> ArrayLike: + return self.train.samples - Return: - DataLoader: Dataloader for the given dataset. - """ - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) + def _get_train_targets(self) -> ArrayLike: + return self.train.label_data @classmethod def add_argparse_args( @@ -225,12 +192,11 @@ def add_argparse_args( parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=256) - p.add_argument("--num_workers", type=int, default=4) - p.add_argument("--evaluate_ood", action="store_true") + p = super().add_argparse_args(parent_parser) + + # Arguments for Tiny Imagenet p.add_argument( "--rand_augment", dest="rand_augment_opt", type=str, default=None ) + p.add_argument("--evaluate_ood", action="store_true") return parent_parser diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 48de2301..d31f5ed3 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -3,14 +3,14 @@ from pathlib import Path from typing import Any, Optional, Tuple, Union -from pytorch_lightning import LightningDataModule from torch import Generator -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import random_split from ..datasets.regression import UCIRegression +from .abstract import AbstractDataModule -class UCIDataModule(LightningDataModule): +class UCIDataModule(AbstractDataModule): """The UCI regression datasets. 
Args:
@@ -45,17 +45,15 @@ def __init__(
         split_seed: int = 42,
         **kwargs,
     ) -> None:
-        super().__init__()
+        super().__init__(
+            root=root,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            persistent_workers=persistent_workers,
+        )

-        if isinstance(root, str):
-            root = Path(root)
-        self.root: Path = root
-        self.batch_size = batch_size
         self.val_split = val_split
-        self.num_workers = num_workers
-
-        self.pin_memory = pin_memory
-        self.persistent_workers = persistent_workers

         self.dataset = partial(
             UCIRegression, dataset_name=dataset_name, seed=split_seed
@@ -85,51 +83,14 @@ def setup(self, stage: Optional[str] = None) -> None:
         if self.val_split == 0:
             self.val = self.test

-    def train_dataloader(self) -> DataLoader:
-        """Get the training dataloader for UCI Regression.
-
-        Return:
-            DataLoader: UCI Regression training dataloader.
-        """
-        return self._data_loader(self.train, shuffle=True)
-
-    def val_dataloader(self) -> DataLoader:
-        """Get the validation dataloader for UCI Regression.
-
-        Return:
-            DataLoader: UCI Regression validation dataloader.
-        """
-        return self._data_loader(self.val)
-
-    def test_dataloader(self) -> DataLoader:
-        """Get the test dataloader for UCI Regression.
+    # test_dataloader is now inherited from AbstractDataModule and returns a
+    # List[DataLoader] by default; the previous implementation is kept below
+    # for reference:
+    # def test_dataloader(self) -> DataLoader:
+    #     """Get the test dataloader for UCI Regression.

-        Return:
-            DataLoader: UCI Regression test dataloader.
-        """
-        return self._data_loader(self.test)
-
-    def _data_loader(
-        self, dataset: Dataset, shuffle: bool = False
-    ) -> DataLoader:
-        """Create a dataloader for a given dataset.
-
-        Args:
-            dataset (Dataset): Dataset to create a dataloader for.
-            shuffle (bool, optional): Whether to shuffle the dataset. Defaults
-                to False.
-
-        Return:
-            DataLoader: Dataloader for the given dataset.
-        """
-        return DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=shuffle,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            persistent_workers=self.persistent_workers,
-        )
+    #     Return:
+    #         DataLoader: UCI Regression test dataloader.
+    #     """
+    #     return self._data_loader(self.test)

     @classmethod
     def add_argparse_args(
@@ -137,9 +98,6 @@
         parent_parser: ArgumentParser,
         **kwargs: Any,
     ) -> ArgumentParser:
-        p = parent_parser.add_argument_group("datamodule")
-        p.add_argument("--root", type=str, default="./data/")
-        p.add_argument("--batch_size", type=int, default=128)
-        p.add_argument("--val_split", type=float, default=0)
-        p.add_argument("--num_workers", type=int, default=4)
+        super().add_argparse_args(parent_parser)
+
         return parent_parser
diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py
index 284435fb..f670da50 100644
--- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py
+++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py
@@ -2,14 +2,13 @@
 from pathlib import Path
 from typing import Any, Callable, Optional, Tuple

+import numpy as np
 from torchvision.datasets import VisionDataset
 from torchvision.datasets.utils import (
     check_integrity,
     download_and_extract_archive,
 )

-import numpy as np
-
 class CIFAR10C(VisionDataset):
     """The corrupted CIFAR-10-C Dataset.
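Note on the UCIDataModule hunk above: test_dataloader is now inherited from
AbstractDataModule, so callers receive a one-element list rather than a bare
DataLoader. A minimal sketch of the updated call pattern (the dataset name
and input shape below are illustrative, borrowed from the tests):

    from torch_uncertainty.datamodules import UCIDataModule

    dm = UCIDataModule(
        dataset_name="kin8nm", input_shape=(1, 5), root="./data", batch_size=128
    )
    dm.prepare_data()
    dm.setup()
    test_loaders = dm.test_dataloader()  # List[DataLoader] of length 1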
diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_h.py b/torch_uncertainty/datasets/classification/cifar/cifar_h.py index a3f17435..f4127055 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_h.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_h.py @@ -1,12 +1,11 @@ import os from typing import Any, Callable, Optional +import numpy as np import torch from torchvision.datasets import CIFAR10 from torchvision.datasets.utils import check_integrity, download_url -import numpy as np - class CIFAR10H(CIFAR10): """`CIFAR-10H `_ Dataset. diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py index 6e67f174..a37f43f2 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py @@ -3,12 +3,11 @@ from pathlib import Path from typing import Callable, Literal, Optional +import numpy as np import torch from PIL import Image from torch.utils.data import Dataset -import numpy as np - class TinyImageNet(Dataset): """Inspired by diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index d07373af..b6ce0d04 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -2,14 +2,13 @@ from pathlib import Path from typing import Any, Callable, Literal, Optional, Tuple +import numpy as np from torchvision.datasets import VisionDataset from torchvision.datasets.utils import ( check_integrity, download_and_extract_archive, ) -import numpy as np - class MNISTC(VisionDataset): """The corrupted MNIST-C Dataset. diff --git a/torch_uncertainty/layers/bayesian/sampler.py b/torch_uncertainty/layers/bayesian/sampler.py index d15f2f4f..01ae5791 100644 --- a/torch_uncertainty/layers/bayesian/sampler.py +++ b/torch_uncertainty/layers/bayesian/sampler.py @@ -1,10 +1,9 @@ from typing import Optional +import numpy as np import torch from torch import Tensor, distributions, nn -import numpy as np - class TrainableDistribution(nn.Module): lsqrt2pi = torch.tensor(np.log(np.sqrt(2 * np.pi))) diff --git a/torch_uncertainty/layers/masksembles.py b/torch_uncertainty/layers/masksembles.py index 834d1d4f..ade2737b 100644 --- a/torch_uncertainty/layers/masksembles.py +++ b/torch_uncertainty/layers/masksembles.py @@ -2,12 +2,11 @@ from typing import Any, Union +import numpy as np import torch from torch import Tensor, nn from torch.nn.common_types import _size_2_t -import numpy as np - def _generate_masks(m: int, n: int, s: float) -> np.ndarray: """Generates set of binary masks with properties defined by n, m, s params. 
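For context, the cross-validation machinery introduced in
torch_uncertainty/datamodules/abstract.py above is driven as in the sketch
below. This is hypothetical usage, not part of the patch; any datamodule
implementing _get_train_data and _get_train_targets behaves the same way:

    from torch_uncertainty.datamodules import CIFAR10DataModule

    dm = CIFAR10DataModule(root="./data", batch_size=128)
    # StratifiedKFold over the training targets; only the first
    # `train_over` of the `n_splits` folds are kept.
    folds = dm.make_cross_val_splits(n_splits=5, train_over=2)
    train_loader = folds[0].train_dataloader()  # SubsetRandomSampler over fold indices
    val_loader = folds[0].val_dataloader()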
diff --git a/torch_uncertainty/metrics/fpr95.py b/torch_uncertainty/metrics/fpr95.py index 5b870b3b..7064c050 100644 --- a/torch_uncertainty/metrics/fpr95.py +++ b/torch_uncertainty/metrics/fpr95.py @@ -1,14 +1,13 @@ from typing import List +import numpy as np import torch +from numpy.typing import ArrayLike from torch import Tensor from torchmetrics import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -import numpy as np -from numpy.typing import ArrayLike - def stable_cumsum(arr: ArrayLike, rtol: float = 1e-05, atol: float = 1e-08): """ diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py index c8623c4e..2e18a54d 100644 --- a/torch_uncertainty/models/resnet/std.py +++ b/torch_uncertainty/models/resnet/std.py @@ -324,6 +324,18 @@ def forward(self, x: Tensor) -> Tensor: out = self.linear(out) return out + def feats_forward(self, x: Tensor) -> Tensor: + x = self.handle_dropout(x) + out = F.relu(self.bn1(self.conv1(x))) + out = self.optional_pool(out) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = self.pool(out) + out = self.flatten(out) + return out + def handle_dropout(self, x: Tensor) -> Tensor: if self.num_estimators is not None: if not self.training: diff --git a/torch_uncertainty/models/wideresnet/packed.py b/torch_uncertainty/models/wideresnet/packed.py index bd4372ab..242e8116 100644 --- a/torch_uncertainty/models/wideresnet/packed.py +++ b/torch_uncertainty/models/wideresnet/packed.py @@ -6,7 +6,6 @@ from ...layers import PackedConv2d, PackedLinear - __all__ = [ "packed_wideresnet28x10", ] diff --git a/torch_uncertainty/optimization_procedures.py b/torch_uncertainty/optimization_procedures.py index 3bb728aa..900afe4a 100644 --- a/torch_uncertainty/optimization_procedures.py +++ b/torch_uncertainty/optimization_procedures.py @@ -17,6 +17,10 @@ "optim_imagenet_resnet50", "optim_imagenet_resnet50_A3", "optim_regression", + "optim_cifar10_resnet34", + "optim_cifar100_resnet34", + "optim_tinyimagenet_resnet34", + "optim_tinyimagenet_resnet50", ] @@ -227,6 +231,92 @@ def optim_imagenet_resnet50_A3( } +def optim_cifar10_resnet34( + model: nn.Module, +) -> Dict[str, Union[Optimizer, LRScheduler]]: + optimizer = optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[100, 150], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + +def optim_cifar100_resnet34( + model: nn.Module, +) -> Dict[str, Union[Optimizer, LRScheduler]]: + optimizer = optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[100, 150], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + +def optim_tinyimagenet_resnet34( + model: nn.Module, +) -> Dict[str, Union[Optimizer, LRScheduler]]: + """Optimization procedure from 'The Devil is in the Margin: Margin-based + Label Smoothing for Network Calibration', + (CVPR 2022, https://arxiv.org/abs/2111.15430): + "We train for 100 epochs with a learning rate of 0.1 for the first + 40 epochs, of 0.01 for the next 20 epochs and of 0.001 for the last + 40 epochs." 
+ """ + optimizer = optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[40, 60], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + +def optim_tinyimagenet_resnet50( + model: nn.Module, +) -> Dict[str, Union[Optimizer, LRScheduler]]: + """Optimization procedure from 'The Devil is in the Margin: Margin-based + Label Smoothing for Network Calibration', + (CVPR 2022, https://arxiv.org/abs/2111.15430): + "We train for 100 epochs with a learning rate of 0.1 for the first + 40 epochs, of 0.01 for the next 20 epochs and of 0.001 for the last + 40 epochs." + """ + optimizer = optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[40, 60], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + def optim_regression( model: nn.Module, learning_rate: float = 1e-2, @@ -307,12 +397,21 @@ def get_procedure( elif ds_name == "cifar100": procedure = optim_cifar100_resnet18 else: - raise NotImplementedError(f"No recipe for dataset:{ds_name}.") + raise NotImplementedError(f"Dataset {ds_name} not implemented.") + elif arch_name == "resnet34": + if ds_name == "cifar10": + procedure = optim_cifar10_resnet34 + elif ds_name == "cifar100": + procedure = optim_cifar100_resnet34 + elif ds_name == "tiny-imagenet": + procedure = optim_tinyimagenet_resnet34 elif arch_name == "resnet50": if ds_name == "cifar10": procedure = optim_cifar10_resnet50 elif ds_name == "cifar100": procedure = optim_cifar100_resnet50 + elif ds_name == "tiny-imagenet": + procedure = optim_tinyimagenet_resnet50 elif ds_name == "imagenet": if imagenet_recipe is not None and imagenet_recipe == "A3": procedure = optim_imagenet_resnet50_A3 diff --git a/torch_uncertainty/post_processing/calibration/matrix_scaler.py b/torch_uncertainty/post_processing/calibration/matrix_scaler.py index 9ac30d2b..9fc5b14c 100644 --- a/torch_uncertainty/post_processing/calibration/matrix_scaler.py +++ b/torch_uncertainty/post_processing/calibration/matrix_scaler.py @@ -35,7 +35,7 @@ def __init__( init_b: float = 0, lr: float = 0.1, max_iter: int = 200, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Optional[Literal["cpu", "cuda"] | torch.device] = None, ) -> None: super().__init__(lr=lr, max_iter=max_iter, device=device) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 6e777a38..e20cb918 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -29,7 +29,7 @@ def __init__( self, lr: float = 0.1, max_iter: int = 100, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Optional[Literal["cpu", "cuda"] | torch.device] = None, ) -> None: super().__init__() self.device = device diff --git a/torch_uncertainty/post_processing/calibration/temperature_scaler.py b/torch_uncertainty/post_processing/calibration/temperature_scaler.py index 3ae23f05..5a10c5ef 100644 --- a/torch_uncertainty/post_processing/calibration/temperature_scaler.py +++ b/torch_uncertainty/post_processing/calibration/temperature_scaler.py @@ -29,7 +29,7 @@ def __init__( init_val: float = 1, lr: float = 0.1, max_iter: int = 100, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Optional[Literal["cpu", "cuda"] | torch.device] = 
None, ) -> None: super().__init__(lr=lr, max_iter=max_iter, device=device) diff --git a/torch_uncertainty/post_processing/calibration/vector_scaler.py b/torch_uncertainty/post_processing/calibration/vector_scaler.py index fb37d049..9cc55f0e 100644 --- a/torch_uncertainty/post_processing/calibration/vector_scaler.py +++ b/torch_uncertainty/post_processing/calibration/vector_scaler.py @@ -35,7 +35,7 @@ def __init__( init_b: float = 0, lr: float = 0.1, max_iter: int = 200, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Optional[Literal["cpu", "cuda"] | torch.device] = None, ) -> None: super().__init__(lr=lr, max_iter=max_iter, device=device) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index c0f2a702..4b86dc8e 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -1,6 +1,6 @@ from argparse import ArgumentParser, Namespace from functools import partial -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import pytorch_lightning as pl import torch @@ -9,7 +9,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.memory import get_model_size_mb from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT -from timm.data import Mixup +from timm.data import Mixup as timm_Mixup from torch import Tensor, nn from torchmetrics import Accuracy, CalibrationError, MetricCollection from torchmetrics.classification import ( @@ -31,6 +31,8 @@ VariationRatio, ) from ..plotting_utils import CalibrationPlot, plot_hist +from ..post_processing import TemperatureScaler +from ..transforms import Mixup, MixupIO, RegMixup, WarpingMixup class ClassificationSingle(pl.LightningModule): @@ -59,11 +61,17 @@ def __init__( loss: Type[nn.Module], optimization_procedure: Any, format_batch_fn: nn.Module = nn.Identity(), + mixtype: str = "erm", + mixmode: str = "elem", + dist_sim: str = "emb", + kernel_tau_max: float = 1.0, + kernel_tau_std: float = 0.5, mixup_alpha: float = 0, cutmix_alpha: float = 0, evaluate_ood: bool = False, use_entropy: bool = False, use_logits: bool = False, + calibration_set: Optional[Callable] = None, **kwargs, ) -> None: super().__init__() @@ -74,6 +82,7 @@ def __init__( "loss", "optimization_procedure", "format_batch_fn", + "calibration_set", ] ) @@ -85,6 +94,8 @@ def __init__( self.use_logits = use_logits self.use_entropy = use_entropy + self.calibration_set = calibration_set + self.binary_cls = num_classes == 1 self.model = model @@ -121,6 +132,9 @@ def __init__( self.val_cls_metrics = cls_metrics.clone(prefix="hp/val_") self.test_cls_metrics = cls_metrics.clone(prefix="hp/test_") + if self.calibration_set is not None: + self.ts_cls_metrics = cls_metrics.clone(prefix="hp/ts_") + self.test_entropy_id = Entropy() if self.evaluate_ood: @@ -140,12 +154,14 @@ def __init__( "Cutmix alpha and Mixup alpha must be positive." f"Got {mixup_alpha} and {cutmix_alpha}." 
) - elif mixup_alpha > 0 or cutmix_alpha > 0: - self.mixup = Mixup( - mixup_alpha=mixup_alpha, cutmix_alpha=cutmix_alpha - ) - else: - self.mixup = lambda x, y: (x, y) + + self.mixtype = mixtype + self.mixmode = mixmode + self.dist_sim = dist_sim + + self.mixup = self.init_mixup( + mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std + ) self.cal_plot = CalibrationPlot() @@ -190,13 +206,26 @@ def on_train_start(self) -> None: "hp/test_aupr": 0, "hp/test_auroc": 0, "hp/test_fpr95": 0, + "hp/ts_test_nll": 0, + "hp/ts_test_ece": 0, + "hp/ts_test_brier": 0, }, ) def training_step( self, batch: Tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: - batch = self.mixup(*batch) + if self.mixtype == "kernel_warping": + if self.dist_sim == "emb": + with torch.no_grad(): + feats = self.model.feats_forward(batch[0]).detach() + + batch = self.mixup(*batch, feats) + elif self.dist_sim == "inp": + batch = self.mixup(*batch, batch[0]) + else: + batch = self.mixup(*batch) + inputs, targets = self.format_batch_fn(batch) if self.is_elbo: @@ -234,6 +263,16 @@ def validation_epoch_end( self.log_dict(self.val_cls_metrics.compute()) self.val_cls_metrics.reset() + def on_test_start(self) -> None: + if self.calibration_set is not None: + self.scaler = TemperatureScaler(device=self.device).fit( + model=self.model, calibration_set=self.calibration_set() + ) + self.cal_model = torch.nn.Sequential(self.model, self.scaler) + else: + self.scaler = None + self.cal_model = None + def test_step( self, batch: Tuple[Tensor, Tensor], @@ -258,6 +297,15 @@ def test_step( else: ood_scores = -confs + if ( + self.calibration_set is not None + and self.scaler is not None + and self.cal_model is not None + ): + cal_logits = self.cal_model(inputs) + cal_probs = F.softmax(cal_logits, dim=-1) + self.ts_cls_metrics.update(cal_probs, targets) + if dataloader_idx == 0: self.test_cls_metrics.update(probs, targets) self.test_entropy_id(probs) @@ -290,6 +338,14 @@ def test_epoch_end( ) self.test_cls_metrics.reset() + if ( + self.calibration_set is not None + and self.scaler is not None + and self.cal_model is not None + ): + self.log_dict(self.ts_cls_metrics.compute()) + self.ts_cls_metrics.reset() + if self.evaluate_ood: self.log_dict( self.test_ood_metrics.compute(), @@ -323,6 +379,49 @@ def test_epoch_end( "Likelihood Histogram", probs_fig ) + def init_mixup( + self, + mixup_alpha: float, + cutmix_alpha: float, + kernel_tau_max: float, + kernel_tau_std: float, + ) -> Callable: + if self.mixtype == "timm": + return timm_Mixup( + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + ) + elif self.mixtype == "mixup": + return Mixup( + alpha=mixup_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + ) + elif self.mixtype == "mixup_io": + return MixupIO( + alpha=mixup_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + ) + elif self.mixtype == "regmixup": + return RegMixup( + alpha=mixup_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + ) + elif self.mixtype == "kernel_warping": + return WarpingMixup( + alpha=mixup_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + apply_kernel=True, + tau_max=kernel_tau_max, + tau_std=kernel_tau_std, + ) + return lambda x, y: (x, y) + @staticmethod def add_model_specific_args( parent_parser: ArgumentParser, @@ -335,10 +434,10 @@ def add_model_specific_args( - ``--logits``: sets :attr:`use_logits` to ``True``. 
""" parent_parser.add_argument( - "--mixup", dest="mixup_alpha", type=float, default=0 + "--mixup_alpha", dest="mixup_alpha", type=float, default=0 ) parent_parser.add_argument( - "--cutmix", dest="cutmix_alpha", type=float, default=0 + "--cutmix_alpha", dest="cutmix_alpha", type=float, default=0 ) parent_parser.add_argument( "--entropy", dest="use_entropy", action="store_true" @@ -346,6 +445,22 @@ def add_model_specific_args( parent_parser.add_argument( "--logits", dest="use_logits", action="store_true" ) + parent_parser.add_argument( + "--mixtype", dest="mixtype", type=str, default="erm" + ) + parent_parser.add_argument( + "--mixmode", dest="mixmode", type=str, default="elem" + ) + parent_parser.add_argument( + "--dist_sim", dest="dist_sim", type=str, default="emb" + ) + parent_parser.add_argument( + "--kernel_tau_max", dest="kernel_tau_max", type=float, default=1.0 + ) + parent_parser.add_argument( + "--kernel_tau_std", dest="kernel_tau_std", type=float, default=0.5 + ) + return parent_parser diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 2913bf50..2595f4f9 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -1,5 +1,5 @@ from argparse import ArgumentParser, Namespace -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Type, Union import pytorch_lightning as pl import torch @@ -17,7 +17,7 @@ class RegressionSingle(pl.LightningModule): def __init__( self, model: nn.Module, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, dist_estimation: int, **kwargs, @@ -209,7 +209,7 @@ class RegressionEnsemble(RegressionSingle): def __init__( self, model: nn.Module, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, dist_estimation: int, num_estimators: int, diff --git a/torch_uncertainty/transforms/__init__.py b/torch_uncertainty/transforms/__init__.py index 3c270d1c..a55bdad6 100644 --- a/torch_uncertainty/transforms/__init__.py +++ b/torch_uncertainty/transforms/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .cutout import Cutout +from .mixup import Mixup, MixupIO, RegMixup, WarpingMixup from .transforms import ( AutoContrast, Brightness, diff --git a/torch_uncertainty/transforms/cutout.py b/torch_uncertainty/transforms/cutout.py index f84243b6..98865547 100644 --- a/torch_uncertainty/transforms/cutout.py +++ b/torch_uncertainty/transforms/cutout.py @@ -1,8 +1,7 @@ +import numpy as np import torch from torch import nn -import numpy as np - class Cutout(nn.Module): """Cutout augmentation class. diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py new file mode 100644 index 00000000..e39d6190 --- /dev/null +++ b/torch_uncertainty/transforms/mixup.py @@ -0,0 +1,228 @@ +from typing import Tuple + +import numpy as np +import scipy +import torch +import torch.nn.functional as F +from torch import Tensor + + +def beta_warping(x, alpha_cdf: float = 1.0, eps: float = 1e-12) -> float: + return scipy.stats.beta.cdf(x, a=alpha_cdf + eps, b=alpha_cdf + eps) + + +def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: + dist_rate = tau_max * np.exp( + -(dist - 1) / (np.mean(dist) * 2 * tau_std * tau_std) + ) + return 1 / (dist_rate + 1e-12) + + +# def tensor_linspace(start: Tensor, stop: Tensor, num: int): +# """ +# Creates a tensor of shape [num, *start.shape] whose values are evenly +# spaced from start to end, inclusive. 
+# Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch. +# """ +# # create a tensor of 'num' steps from 0 to 1 +# steps = torch.arange(num, dtype=torch.float32, device=start.device) / ( +# num - 1 +# ) + +# # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] +# # to allow for broadcasting +# # using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here +# # but torchscript +# # "cannot statically infer the expected size of a list in this context", +# # hence the code below +# for i in range(start.ndim): +# steps = steps.unsqueeze(-1) + +# # the output starts at 'start' and increments until 'stop' in each dimension +# out = start[None] + steps * (stop - start)[None] + +# return out + + +# def torch_beta_cdf( +# x: Tensor, c1: Tensor | float, c2: Tensor | float, npts=100, eps=1e-12 +# ): +# if isinstance(c1, float): +# if c1 == c2: +# c1 = Tensor([c1], device=x.device) +# c2 = c1 +# else: +# c1 = Tensor([c1], device=x.device) +# if isinstance(c2, float): +# c2 = Tensor([c2], device=x.device) +# bt = torch.distributions.Beta(c1, c2) + +# if isinstance(x, float): +# x = Tensor(x) + +# X = tensor_linspace(torch.zeros_like(x) + eps, x, npts) +# return torch.trapezoid(bt.log_prob(X).exp(), X, dim=0) + + +# def torch_beta_warping( +# x: Tensor, alpha_cdf: float | Tensor = 1.0, eps=1e-12, npts=100 +# ): +# return torch_beta_cdf( +# x=x, c1=alpha_cdf + eps, c2=alpha_cdf + eps, npts=npts, eps=eps +# ) + + +# def torch_sim_gauss_kernel(dist: Tensor, tau_max=1.0, tau_std=0.5): +# dist_rate = tau_max * torch.exp( +# -(dist - 1) / (torch.mean(dist) * 2 * tau_std * tau_std) +# ) + +# return 1 / (dist_rate + 1e-12) + + +class AbstractMixup: + def __init__( + self, alpha: float = 1.0, mode: str = "batch", num_classes: int = 1000 + ) -> None: + self.alpha = alpha + self.num_classes = num_classes + self.mode = mode + + def _get_params(self, batch_size: int, device: torch.device): + if self.mode == "batch": + lam = np.random.beta(self.alpha, self.alpha) + else: + lam = torch.as_tensor( + np.random.beta(self.alpha, self.alpha, batch_size), + device=device, + ) + index = torch.randperm(batch_size, device=device) + return lam, index + + def _linear_mixing( + self, + lam: Tensor | float, + inp: Tensor, + index: Tensor, + ) -> Tensor: + if isinstance(lam, Tensor): + lam = lam.view(-1, *[1 for _ in range(inp.ndim - 1)]).float() + + return lam * inp + (1 - lam) * inp[index, :] + + def _mix_target( + self, + lam: Tensor | float, + target: Tensor, + index: Tensor, + ) -> Tensor: + y1 = F.one_hot(target, self.num_classes) + y2 = F.one_hot(target[index], self.num_classes) + if isinstance(lam, Tensor): + lam = lam.view(-1, *[1 for _ in range(y1.ndim - 1)]).float() + + return lam * y1 + (1 - lam) * y2 + + def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError + + +class Mixup(AbstractMixup): + """Original Mixup method from Zhang et al., + "mixup: Beyond Empirical Risk Minimization" (ICLR 2018) + http://arxiv.org/abs/1710.09412 + """ + + def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + lam, index = self._get_params(x.size()[0], x.device) + mixed_x = self._linear_mixing(lam, x, index) + mixed_y = self._mix_target(lam, y, index) + return mixed_x, mixed_y + + +class MixupIO(AbstractMixup): + """Mixup on inputs only with targets unchanged, from Wang et al., + "On the Pitfall of Mixup for Uncertainty Calibration" (CVPR 2023)
https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_On_the_Pitfall_of_Mixup_for_Uncertainty_Calibration_CVPR_2023_paper.pdf + """ + + def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + lam, index = self._get_params(x.size()[0], x.device) + + mixed_x = self._linear_mixing(lam, x, index) + + if self.mode == "batch": + mixed_y = self._mix_target(float(lam > 0.5), y, index) + else: + mixed_y = self._mix_target((lam > 0.5).float(), y, index) + + return mixed_x, mixed_y + + +class RegMixup(AbstractMixup): + """RegMixup method from Pinto et al., + "RegMixup: Mixup as a Regularizer Can Surprisingly Improve Accuracy and Out-of-Distribution Robustness" (NeurIPS 2022) + https://arxiv.org/abs/2206.14502 + """ + + def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + lam, index = self._get_params(x.size()[0], x.device) + part_x = self._linear_mixing(lam, x, index) + part_y = self._mix_target(lam, y, index) + mixed_x = torch.cat([x, part_x], dim=0) + mixed_y = torch.cat([F.one_hot(y, self.num_classes), part_y], dim=0) + return mixed_x, mixed_y + + +class WarpingMixup(AbstractMixup): + """Kernel Warping Mixup method from Bouniot et al., + "Tailoring Mixup to Data using Kernel Warping functions" (2023) + https://arxiv.org/abs/2311.01434 + """ + + def __init__( + self, + alpha: float = 1.0, + mode: str = "batch", + num_classes: int = 1000, + apply_kernel: bool = True, + tau_max: float = 1.0, + tau_std: float = 0.5, + ) -> None: + super().__init__(alpha, mode, num_classes) + self.apply_kernel = apply_kernel + self.tau_max = tau_max + self.tau_std = tau_std + + def _get_params(self, batch_size: int, device: torch.device): + if self.mode == "batch": + lam = np.random.beta(self.alpha, self.alpha) + else: + lam = np.random.beta(self.alpha, self.alpha, batch_size) + + index = torch.randperm(batch_size, device=device) + return lam, index + + def __call__( + self, + x: Tensor, + y: Tensor, + feats: Tensor, + warp_param=1.0, + ) -> Tuple[Tensor, Tensor]: + lam, index = self._get_params(x.size()[0], x.device) + + if self.apply_kernel: + l2_dist = ( + (feats - feats[index]) + .pow(2) + .sum([i for i in range(len(feats.size())) if i > 0]) + .cpu() + .numpy() + ) + warp_param = sim_gauss_kernel(l2_dist, self.tau_max, self.tau_std) + + k_lam = torch.as_tensor(beta_warping(lam, warp_param), device=x.device) + mixed_x = self._linear_mixing(k_lam, x, index) + mixed_y = self._mix_target(k_lam, y, index) + return mixed_x, mixed_y diff --git a/torch_uncertainty/transforms/pixmix.py b/torch_uncertainty/transforms/pixmix.py index d9d96941..56a0e7b3 100644 --- a/torch_uncertainty/transforms/pixmix.py +++ b/torch_uncertainty/transforms/pixmix.py @@ -1,9 +1,9 @@ """ Code adapted from PixMix' paper.
""" +import numpy as np from PIL import Image from torch import nn -import numpy as np from torch_uncertainty.transforms import Shear, Translate, augmentations diff --git a/torch_uncertainty/transforms/transforms.py b/torch_uncertainty/transforms/transforms.py index 950ff94a..03550328 100644 --- a/torch_uncertainty/transforms/transforms.py +++ b/torch_uncertainty/transforms/transforms.py @@ -1,13 +1,12 @@ from typing import List, Optional, Tuple, Union +import numpy as np import torch import torchvision.transforms.functional as F from einops import rearrange from PIL import Image, ImageEnhance from torch import Tensor, nn -import numpy as np - class AutoContrast(nn.Module): pixmix_max_level = None diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index 1b739c7b..e6b70312 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -1,3 +1,4 @@ # ruff: noqa: F401 from .checkpoints import get_version from .hub import load_hf +from .misc import csv_writer diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py new file mode 100644 index 00000000..19feab91 --- /dev/null +++ b/torch_uncertainty/utils/misc.py @@ -0,0 +1,19 @@ +import csv + + +def csv_writer(path, dic): + # Check if the file already exists + if path.is_file(): + append_mode = True + rw_mode = "a" + else: + append_mode = False + rw_mode = "w" + + # Write dic + with open(path, rw_mode) as csvfile: + writer = csv.writer(csvfile, delimiter=",") + # Do not write header in append mode + if append_mode is False: + writer.writerow(dic.keys()) + writer.writerow([f"{elem:.4f}" for elem in dic.values()])