Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Boyan-MILANOV committed Oct 4, 2024
1 parent a0ee16b commit 2673994
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 147 deletions.
123 changes: 73 additions & 50 deletions pickle_scanning_benchmark/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,48 @@
from dataclasses import dataclass, field
import json
from typing import Dict
from pathlib import Path
import os

# To run model unpickler
import pickle
import pickletools
import random
from fickling.fickle import Pickled
from fickling.analysis import check_safety, Severity, UnsafeImportsML, BadCalls, Analyzer
from fickling.ml import MLAllowlist
from fickling.pytorch import PyTorchModelWrapper
import sys
import logger
from typing import Optional, Callable
import pickletools
import os
from modelscan.modelscan import ModelScan
import picklescan.scanner as ps_scanner
import traceback
from model_unpickler import SafeUnpickler
import zipfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, Optional

# To run model unpickler
import pickle
import logger
import picklescan.scanner as ps_scanner
import torch._C
from model_unpickler import SafeUnpickler
from modelscan.modelscan import ModelScan

from fickling.analysis import Analyzer, BadCalls, Severity, UnsafeImportsML, check_safety
from fickling.fickle import Pickled
from fickling.ml import MLAllowlist
from fickling.pytorch import PyTorchModelWrapper


class ModelUnpickler(SafeUnpickler):
    """``SafeUnpickler`` subclass whose persistent ids all resolve to ``None``."""

    def persistent_load(self, pid):
        # The persistent id is ignored: whatever ``pid`` is, the
        # unpickler receives None in place of the referenced object.
        return None


setattr(pickle, "Unpickler", SafeUnpickler)

# Logging
ps_scanner._log.propagate = False
DEVNULL = open(os.devnull,"w")
DEVNULL = open(os.devnull, "w")

class InvalidPickleException(Exception):
    """Project-specific exception; by its name it signals an invalid pickle file.

    NOTE(review): the raise sites are not visible in this section -- confirm
    the exact conditions against the callers.
    """

# TODO(boyan): do this when downloading the files
def is_valid(filepath, filetype):
with open(filepath, "rb") as f:
if f.read(100).startswith(b"Access to model"):
# HF access denied...
return False

if filetype == "pickle":
with open(filepath, "rb") as f:
try:
Expand All @@ -53,21 +56,24 @@ def is_valid(filepath, filetype):
else:
raise Exception("Unsupported file type")


def load_index(filepath):
    """Parse the JSON file at ``filepath`` and return the decoded object.

    Args:
        filepath: Path (``str`` or path-like) of the JSON index file.

    Returns:
        Whatever ``json.load`` produces for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode as UTF-8 explicitly: without ``encoding`` Python falls back to the
    # locale's preferred encoding, which can mis-decode non-ASCII file names
    # stored in the index on platforms whose default is not UTF-8.
    with open(filepath, encoding="utf-8") as f:
        return json.load(f)


def run_fickling(filepath, filetype):
    """Run the fickling analyses on ``filepath`` using the scanner for ``filetype``.

    ``"pickle"`` is routed to ``run_fickling_pickle`` and ``"pytorch"`` to
    ``run_fickling_pytorch``; the chosen scanner's result is returned as is.
    Any other ``filetype`` yields ``None``.
    """
    checks = [
        MLAllowlist(),  # Import non standard non whitelisted stuff
        UnsafeImportsML(),  # Importing from unsafe modules
        BadCalls(),  # Overtly bad calls to built-in functions
    ]
    # Dispatch on the file type; an unknown type has no scanner.
    scanners = {
        "pickle": run_fickling_pickle,
        "pytorch": run_fickling_pytorch,
    }
    scanner = scanners.get(filetype)
    if scanner is None:
        return None
    return scanner(filepath, checks)


def run_fickling_pickle(filepath, analysis) -> bool:
"""Return true if the file is considered safe"""
with open(filepath, "rb") as f:
Expand All @@ -76,6 +82,7 @@ def run_fickling_pickle(filepath, analysis) -> bool:
print(res)
return res.severity == Severity.LIKELY_SAFE


def run_fickling_pytorch(filepath, analysis) -> bool:
wrapper = PyTorchModelWrapper(filepath)
res = check_safety(wrapper.pickled, analyzer=Analyzer(analysis))
Expand All @@ -90,43 +97,45 @@ def run_modelscan(filepath, filetype):
return False
return True


def run_modelunpickler(filepath, filetype):
    """Try to load ``filepath`` and report whether loading went through.

    ``"pytorch"`` files are loaded with ``torch.load`` on the CPU device and
    ``"pickle"`` files with ``ModelUnpickler``. Returns ``False`` (after
    printing the error) when loading raises ``pickle.UnpicklingError`` or
    ``AttributeError``, and ``True`` otherwise -- including for any other
    ``filetype``, for which nothing is loaded. Other exceptions propagate.
    """
    try:
        if filetype == "pytorch":
            torch.load(filepath, map_location=torch.device("cpu"))
        elif filetype == "pickle":
            with open(filepath, "rb") as stream:
                # Only the success or failure of the load matters; the
                # unpickled object itself is discarded.
                ModelUnpickler(stream).load()
    except (pickle.UnpicklingError, AttributeError) as err:
        print(err)
        return False
    else:
        return True


def run_picklescan(filepath, filetype):
    """Scan ``filepath`` with picklescan; ``True`` means no issue was reported.

    ``filetype`` is accepted so the signature matches the other tool runners
    but is not used here.

    Raises:
        Exception: If the scan result has ``scan_err`` set.
    """
    scan = ps_scanner.scan_file_path(filepath)
    if scan.scan_err:
        raise Exception("Failed to analyze file with picklescan. res.scan_err = True")
    # Zero issues -> safe (True); anything else -> unsafe (False).
    return scan.issues_count == 0


def _analyze_file(
toolname: str,
run_tool_func: Callable,
fileinfo: Dict,
results: Dict[str, "BenchmarkResults"],
expected_scan_result: bool, # True for clean files, False for malicious files
expected_scan_result: bool, # True for clean files, False for malicious files
payload: Optional[str] = None,
):
logger.info(f"Running {toolname} on {fileinfo['file']}")
# Run tool
if expected_scan_result == True:
if expected_scan_result:
try:
clean = run_tool_func(fileinfo["file"], fileinfo["type"])
if clean:
if clean:
results.tools[toolname].add_tn()
else:
results.tools[toolname].add_fp()
Expand All @@ -140,9 +149,11 @@ def _analyze_file(
else:
try:
clean = run_tool_func(fileinfo["file"], fileinfo["type"])
if clean:
if clean:
results.tools[toolname].add_fn(payload=payload)
logger.warning(f"Malicious file missed by {toolname}: {fileinfo['file']}. Payload was: {payload}")
logger.warning(
f"Malicious file missed by {toolname}: {fileinfo['file']}. Payload was: {payload}"
)
else:
results.tools[toolname].add_tp()
except KeyboardInterrupt as e:
Expand All @@ -152,14 +163,21 @@ def _analyze_file(
logger.error(f"Failed to analyze file: {e}")
results.tools[toolname].nb_failed_files += 1

def run_benchmark(clean_dataset_dir: Path, malicious_dataset_dir: Path, tools: dict, n=10000, clean_to_malicious_ratio=2.0):

def run_benchmark(
clean_dataset_dir: Path,
malicious_dataset_dir: Path,
tools: dict,
n=10000,
clean_to_malicious_ratio=2.0,
):
# Load file indexes
clean_index = load_index(clean_dataset_dir / "index.json")
malicious_index = load_index(malicious_dataset_dir / "index.json")

# Select files for the benchmark
# Get ratio
nb_malicious_files = __builtins__.round(n/(1+clean_to_malicious_ratio))
nb_malicious_files = __builtins__.round(n / (1 + clean_to_malicious_ratio))
nb_clean_files = n - nb_malicious_files
# Don't get more files than we actually have in the datasets
nb_malicious_files = __builtins__.min(nb_malicious_files, len(malicious_index))
Expand All @@ -186,8 +204,10 @@ def run_benchmark(clean_dataset_dir: Path, malicious_dataset_dir: Path, tools: d
continue
results.nb_malicious_files += 1
for toolname, runtool in tools.items():
_analyze_file(toolname, runtool, f, results, expected_scan_result=False, payload=f["payload"])
except KeyboardInterrupt as e:
_analyze_file(
toolname, runtool, f, results, expected_scan_result=False, payload=f["payload"]
)
except KeyboardInterrupt:
pass

# Print results
Expand All @@ -197,15 +217,15 @@ def run_benchmark(clean_dataset_dir: Path, malicious_dataset_dir: Path, tools: d
@dataclass
class ToolResults:
# Overall results
tn_clean: int = 0 # Clean true negatives (good)
fp_clean: int = 0 # Clean false positive (bad)
fn_malicious: int = 0 # Malicious false negative (bad)
tp_malicious: int = 0 # malicious true positive (good)
tn_clean: int = 0 # Clean true negatives (good)
fp_clean: int = 0 # Clean false positive (bad)
fn_malicious: int = 0 # Malicious false negative (bad)
tp_malicious: int = 0 # malicious true positive (good)

nb_scanned_files: int = 0 # Files scanned without errors
nb_failed_files: int = 0 # The tool failed to scan the files
nb_scanned_files: int = 0 # Files scanned without errors
nb_failed_files: int = 0 # The tool failed to scan the files

fn_payload_types: Dict[str, int] = field(default_factory=dict) # <payload type> --> how many
fn_payload_types: Dict[str, int] = field(default_factory=dict) # <payload type> --> how many

@property
def total_files(self) -> int:
Expand Down Expand Up @@ -233,7 +253,10 @@ def add_tp(self, n=1):
self.nb_scanned_files += n

def sanity_check(self):
assert self.tn_clean + self.tp_malicious + self.fn_malicious + self.fp_clean == self.nb_scanned_files
assert (
self.tn_clean + self.tp_malicious + self.fn_malicious + self.fp_clean
== self.nb_scanned_files
)

def to_str(self, bench_res: "BenchmarkResults"):
tn_rate = self.tn_clean / bench_res.nb_clean_files
Expand All @@ -251,12 +274,13 @@ def to_str(self, bench_res: "BenchmarkResults"):

return res


@dataclass
class BenchmarkResults:
# Overall files
nb_clean_files: int = 0 # Total seen clean files
nb_malicious_files: int = 0 # Total seen malicious files
nb_invalid_files: int = 0 # Files where even pickletools fail
nb_clean_files: int = 0 # Total seen clean files
nb_malicious_files: int = 0 # Total seen malicious files
nb_invalid_files: int = 0 # Files where even pickletools fail

tools: Dict[str, ToolResults] = field(default_factory=dict)

Expand All @@ -272,7 +296,6 @@ def total_files(self):
return self.nb_clean_files + self.nb_malicious_files + self.nb_invalid_files

def __str__(self):
invalid_files_rate = self.nb_invalid_files / self.total_files
res = f"""
### Benchmark results
Expand All @@ -295,4 +318,4 @@ def __str__(self):
}
clean_dataset = Path(sys.argv[1])
malicious_dataset = Path(sys.argv[2])
run_benchmark(clean_dataset, malicious_dataset, tools)
run_benchmark(clean_dataset, malicious_dataset, tools)
19 changes: 11 additions & 8 deletions pickle_scanning_benchmark/dataset_stats.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import csv
import json
from pathlib import Path
import sys
from dataclasses import dataclass, field
import os
from pathlib import Path

from huggingface_hub import HfApi

from fickling.fickle import Pickled
from fickling.pytorch import PyTorchModelWrapper
import sys
import csv
from huggingface_hub import HfApi

api = HfApi()


@dataclass
class Stats:
nb_files: int = 0
Expand All @@ -26,7 +28,6 @@ def add(self, file):
self.nb_files += 1
self._record_project(file["project"])
self._record_file_type(file)


def _record_file_type(self, file):
if file["type"] not in self.file_types:
Expand Down Expand Up @@ -54,7 +55,7 @@ def _get_pickled(self, file):
return Pickled.load(f)
elif file["type"] == "pytorch":
return PyTorchModelWrapper(file["file"]).pickled
except:
except Exception:
return None

def _record_project(self, project):
Expand All @@ -78,6 +79,7 @@ def dump_project_downloads(self):
w.writeheader()
w.writerow(self.projects)


def get_stats(dataset_dir: Path):
with open(dataset_dir / "index.json", "rb") as f:
index = json.load(f)
Expand All @@ -88,11 +90,12 @@ def get_stats(dataset_dir: Path):
stats.finalise()
return stats


if __name__ == "__main__":
    # Script entry point: the first command-line argument is the dataset
    # directory, which get_stats() expects to contain an index.json file.
    stats = get_stats(Path(sys.argv[1]))
    stats.dump_imports()
    stats.dump_project_downloads()
    print(stats)

    # Compute the aggregates once and reuse them for every summary line.
    nb_projects = len(stats.projects)
    total_downloads = sum(stats.projects.values())
    print("Total project downloads", total_downloads)
    # Guard the average: an index without any project would otherwise make
    # the division raise ZeroDivisionError after all the work is done.
    print("Avg project download", total_downloads / nb_projects if nb_projects else 0)
    print("Nb projects", nb_projects)
Loading

0 comments on commit 2673994

Please sign in to comment.