-
-
Notifications
You must be signed in to change notification settings - Fork 627
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fix pytorch weights check * B614: Fix PyTorch plugin to handle weights_only parameter correctly The PyTorch plugin (B614) has been updated to properly handle the weights_only parameter in torch.load calls. When weights_only=True is specified, PyTorch will only deserialize known safe types, making the operation more secure. I also removed torch.save as there is no certain insecure element as such, saving any file or artifact requires consideration of what it is you are saving. Changes: - Update plugin to only check torch.load calls (not torch.save) - Fix weights_only check to handle both string and boolean True values - Remove map_location check as it doesn't affect security - Update example file to demonstrate both safe and unsafe cases - Update plugin documentation to mention weights_only as a safe alternative The plugin now correctly identifies unsafe torch.load calls while allowing safe usage with weights_only=True to pass without warning. Fixes: #1224 * Fix E501 line too long * Rename files to new test scope * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update doc/source/plugins/b614_pytorch_load.rst Co-authored-by: Eric Brown <[email protected]> * Update pytorch_load.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Brown <[email protected]>
- Loading branch information
1 parent
def123a
commit 8ff25e0
Showing
7 changed files
with
63 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
------------------ | ||
B614: pytorch_load | ||
------------------ | ||
|
||
.. automodule:: bandit.plugins.pytorch_load |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import torch | ||
import torchvision.models as models | ||
|
||
# Example of saving a model | ||
model = models.resnet18(pretrained=True) | ||
torch.save(model.state_dict(), 'model_weights.pth') | ||
|
||
# Example of loading the model weights in an insecure way (should trigger B614) | ||
loaded_model = models.resnet18() | ||
loaded_model.load_state_dict(torch.load('model_weights.pth')) | ||
|
||
# Example of loading with weights_only=True (should NOT trigger B614) | ||
safe_model = models.resnet18() | ||
safe_model.load_state_dict(torch.load('model_weights.pth', weights_only=True)) | ||
|
||
# Example of loading with weights_only=False (should trigger B614) | ||
unsafe_model = models.resnet18() | ||
unsafe_model.load_state_dict(torch.load('model_weights.pth', weights_only=False)) | ||
|
||
# Example of loading with map_location but no weights_only (should trigger B614) | ||
cpu_model = models.resnet18() | ||
cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu')) | ||
|
||
# Example of loading with both map_location and weights_only=True (should NOT trigger B614) | ||
safe_cpu_model = models.resnet18() | ||
safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True)) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters