Skip to content

Commit

Permalink
Pytorch fix (#1231)
Browse files Browse the repository at this point in the history
* Fix pytorch weights check

* B614: Fix PyTorch plugin to handle weights_only parameter correctly

The PyTorch plugin (B614) has been updated to properly handle the weights_only
parameter in torch.load calls. When weights_only=True is specified, PyTorch will
only deserialize known safe types, making the operation more secure.

I also removed torch.save as there is no certain insecure element as
such, saving any file or artifact requires consideration of what it is
you are saving.

Changes:
- Update plugin to only check torch.load calls (not torch.save)
- Fix weights_only check to handle both string and boolean True values
- Remove map_location check as it doesn't affect security
- Update example file to demonstrate both safe and unsafe cases
- Update plugin documentation to mention weights_only as a safe alternative

The plugin now correctly identifies unsafe torch.load calls while allowing safe
usage with weights_only=True to pass without warning.

Fixes: #1224

* Fix  E501 line too long

* Rename files to new test scope

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update doc/source/plugins/b614_pytorch_load.rst

Co-authored-by: Eric Brown <[email protected]>

* Update pytorch_load.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Brown <[email protected]>
  • Loading branch information
3 people authored Feb 16, 2025
1 parent def123a commit 8ff25e0
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,26 @@
#
# SPDX-License-Identifier: Apache-2.0
r"""
==========================================
B614: Test for unsafe PyTorch load or save
==========================================
==================================
B614: Test for unsafe PyTorch load
==================================
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution, and
improper use of `torch.save` might expose sensitive data or lead to data
corruption. A safe alternative is to use `torch.load` with the `safetensors`
library from hugingface, which provides a safe deserialization mechanism.
This plugin checks for unsafe use of `torch.load`. Using `torch.load` with
untrusted data can lead to arbitrary code execution. There are two safe
alternatives:
1. Use `torch.load` with `weights_only=True` where only tensor data is
extracted, and no arbitrary Python objects are deserialized
2. Use the `safetensors` library from huggingface, which provides a safe
deserialization mechanism
With `weights_only=True`, PyTorch enforces a strict type check, ensuring
that only torch.Tensor objects are loaded.
:Example:
.. code-block:: none
>> Issue: Use of unsafe PyTorch load or save
>> Issue: Use of unsafe PyTorch load
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
Expand All @@ -42,12 +47,11 @@

@test.checks("Call")
@test.test_id("B614")
def pytorch_load_save(context):
def pytorch_load(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead
to data corruption.
This plugin checks for unsafe use of `torch.load`. Using `torch.load`
with untrusted data can lead to arbitrary code execution. The safe
alternative is to use `weights_only=True` or the safetensors library.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
Expand All @@ -59,14 +63,18 @@ def pytorch_load_save(context):
if all(
[
"torch" in qualname_list,
func in ["load", "save"],
not context.check_call_arg_value("map_location", "cpu"),
func == "load",
]
):
# For torch.load, check if weights_only=True is specified
weights_only = context.get_call_arg_value("weights_only")
if weights_only == "True" or weights_only is True:
return

return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
text="Use of unsafe PyTorch load or save",
text="Use of unsafe PyTorch load",
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b614_pytorch_load.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
------------------
B614: pytorch_load
------------------

.. automodule:: bandit.plugins.pytorch_load
5 changes: 0 additions & 5 deletions doc/source/plugins/b614_pytorch_load_save.rst

This file was deleted.

26 changes: 26 additions & 0 deletions examples/pytorch_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import torchvision.models as models

# Example of saving a model
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Example of loading the model weights in an insecure way (should trigger B614)
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Example of loading with weights_only=True (should NOT trigger B614)
safe_model = models.resnet18()
safe_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))

# Example of loading with weights_only=False (should trigger B614)
unsafe_model = models.resnet18()
unsafe_model.load_state_dict(torch.load('model_weights.pth', weights_only=False))

# Example of loading with map_location but no weights_only (should trigger B614)
cpu_model = models.resnet18()
cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))

# Example of loading with both map_location and weights_only=True (should NOT trigger B614)
safe_cpu_model = models.resnet18()
safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True))
21 changes: 0 additions & 21 deletions examples/pytorch_load_save.py

This file was deleted.

4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save
#bandit/plugins/pytorch_load.py
pytorch_load = bandit.plugins.pytorch_load:pytorch_load

# bandit/plugins/trojansource.py
trojansource = bandit.plugins.trojansource:trojansource
Expand Down
10 changes: 5 additions & 5 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,13 +872,13 @@ def test_tarfile_unsafe_members(self):
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
def test_pytorch_load(self):
"""Test insecure usage of torch.load."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4},
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 3, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 3},
}
self.check_example("pytorch_load_save.py", expect)
self.check_example("pytorch_load.py", expect)

def test_trojansource(self):
expect = {
Expand Down

0 comments on commit 8ff25e0

Please sign in to comment.