Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Torch an optional dependency #95

Merged
merged 7 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ Both the library and command line utility can be installed through pip:
python -m pip install fickling
```

PyTorch is an optional dependency of Fickling. Therefore, in order to use Fickling's `pytorch`
and `polyglot` modules, you should run:

```bash
python -m pip install fickling[torch]
```

## Malicious file detection

Fickling can seamlessly be integrated into your codebase to detect and halt the loading of malicious
Expand Down Expand Up @@ -169,10 +176,10 @@ following PyTorch file formats:

* **PyTorch v0.1.1**: Tar file with sys_info, pickle, storages, and tensors
* **PyTorch v0.1.10**: Stacked pickle files
* **TorchScript v1.0**: ZIP file with model.json and constants.pkl (a JSON file and a pickle file)
* **TorchScript v1.1**: ZIP file with model.json and attribute.pkl (a JSON file and a pickle file)
* **TorchScript v1.3**: ZIP file with data.pkl and constants.pkl (2 pickle files)
* **TorchScript v1.4**: ZIP file with data.pkl, constants.pkl, and version (2 pickle files and a folder)
* **TorchScript v1.0**: ZIP file with model.json
* **TorchScript v1.1**: ZIP file with model.json and attributes.pkl
* **TorchScript v1.3**: ZIP file with data.pkl and constants.pkl
* **TorchScript v1.4**: ZIP file with data.pkl, constants.pkl, and version set at 2 or higher (2 pickle files and a folder)
* **PyTorch v1.3**: ZIP file containing data.pkl (1 pickle file)
* **PyTorch model archive format[ZIP]**: ZIP file that includes Python code files and pickle files

Expand Down
25 changes: 16 additions & 9 deletions fickling/polyglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import tarfile
import zipfile

from torch.serialization import _is_zipfile

from fickling.fickle import Pickled, StackedPickle

"""
Expand All @@ -14,10 +12,10 @@
We currently support the following PyTorch file formats:
• PyTorch v0.1.1: Tar file with sys_info, pickle, storages, and tensors
• PyTorch v0.1.10: Stacked pickle files
• TorchScript v1.0: ZIP file with model.json and constants.pkl (a JSON file and a pickle file)
• TorchScript v1.1: ZIP file with model.json and attribute.pkl (a JSON file and a pickle file)
• TorchScript v1.0: ZIP file with model.json
• TorchScript v1.1: ZIP file with model.json and attributes.pkl (a JSON file and a pickle file)
• TorchScript v1.3: ZIP file with data.pkl and constants.pkl (2 pickle files)
• TorchScript v1.4: ZIP file with data.pkl, constants.pkl, and version (2 pickle files and a folder)
• TorchScript v1.4: ZIP file with data.pkl, constants.pkl, and version set at 2 or higher
• PyTorch v1.3: ZIP file containing data.pkl (1 pickle file)
• PyTorch model archive format[ZIP]: ZIP file that includes Python code files and pickle files

Expand All @@ -27,6 +25,15 @@
Another useful reference is https://github.com/lutzroeder/netron/blob/main/source/pytorch.js.
"""

try:
from torch.serialization import _is_zipfile
except ModuleNotFoundError:
raise ImportError(
"The 'torch' module is required for this functionality."
"PyTorch is now an optional dependency in Fickling."
"Please use `pip install fickling[torch]`"
)


def check_and_find_in_zip(
zip_path, file_name_or_extension, return_path=False, check_extension=False
Expand Down Expand Up @@ -99,15 +106,15 @@ def find_file_properties(file_path, print_properties=False):
"has_data_pkl": False,
"has_version": False,
"has_model_json": False,
"has_attribute_pkl": False,
"has_attributes_pkl": False,
}
if is_torch_zip:
torch_zip_checks = [
"data.pkl",
"constants.pkl",
"version",
"model.json",
"attribute.pkl",
"attributes.pkl",
]
torch_zip_results = {
f"has_{'_'.join(f.split('.'))}": check_and_find_in_zip(
Expand Down Expand Up @@ -163,7 +170,7 @@ def check_for_corruption(properties):
if properties["is_torch_zip"]:
if (
properties["has_model_json"]
and not properties["has_attribute_pkl"]
and not properties["has_attributes_pkl"]
and not properties["has_constants_pkl"]
):
corrupted = True
Expand All @@ -188,7 +195,7 @@ def identify_pytorch_file_format(file, print_properties=False, print_results=Fal
(["has_data_pkl", "has_constants_pkl", "has_version"], "TorchScript v1.4"),
(["has_data_pkl", "has_constants_pkl"], "TorchScript v1.3"),
(["has_model_json", "has_constants_pkl"], "TorchScript v1.0"),
(["has_model_json", "has_attribute_pkl"], "TorchScript v1.1"),
(["has_model_json", "has_attributes_pkl"], "TorchScript v1.1"),
(["has_data_pkl"], "PyTorch v1.3"),
]
formats = [
Expand Down
11 changes: 9 additions & 2 deletions fickling/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
from pathlib import Path
from typing import Optional, Set

import torch

import fickling.polyglot
from fickling.fickle import Pickled

try:
import torch
except ModuleNotFoundError:
raise ImportError(
"The 'torch' module is required for this functionality."
"PyTorch is now an optional dependency in Fickling."
"Please use `pip install fickling[torch]`"
)


class BaseInjection(torch.nn.Module):
# This class allows you to combine the payload and original model
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
"Topic :: Utilities",
]
dependencies = ["astunparse ~= 1.6.3", "torch >= 2.1.0", "torchvision >= 0.16.1"]
dependencies = ["astunparse ~= 1.6.3"]
requires-python = ">=3.8"

[project.optional-dependencies]
torch = ["torch >= 2.1.0", "torchvision >= 0.16.1"]
lint = ["black", "mypy", "ruff"]
test = ["pytest", "pytest-cov", "coverage[toml]"]
dev = ["build", "fickling[lint,test]", "twine"]
test = ["pytest", "pytest-cov", "coverage[toml]", "torch >= 2.1.0", "torchvision >= 0.16.1"]
dev = ["build", "fickling[lint,test]", "twine", "torch >= 2.1.0", "torchvision >= 0.16.1"]
examples = ["numpy", "pytorchfi"]

[project.scripts]
Expand Down
8 changes: 4 additions & 4 deletions test/test_polyglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_v1_3_properties(self):
"has_constants_pkl": False,
"has_version": True,
"has_model_json": False,
"has_attribute_pkl": False,
"has_attributes_pkl": False,
}
self.assertEqual(properties, proper_result)

Expand All @@ -144,7 +144,7 @@ def test_legacy_pickle_properties(self):
"has_constants_pkl": False,
"has_version": True,
"has_model_json": False,
"has_attribute_pkl": False,
"has_attributes_pkl": False,
}
self.assertEqual(properties, proper_result)

Expand All @@ -160,7 +160,7 @@ def test_torchscript_properties(self):
"has_constants_pkl": True,
"has_version": True,
"has_model_json": False,
"has_attribute_pkl": False,
"has_attributes_pkl": False,
}
self.assertEqual(properties, proper_result)

Expand All @@ -176,7 +176,7 @@ def test_zip_properties(self):
"has_data_pkl": False,
"has_version": False,
"has_model_json": False,
"has_attribute_pkl": False,
"has_attributes_pkl": False,
}
self.assertEqual(properties, proper_result)

Expand Down
Loading