Skip to content

Commit

Permalink
Merge commit from fork
Browse files Browse the repository at this point in the history
  • Loading branch information
madgetr authored Mar 3, 2025
1 parent 13bb359 commit baf03fa
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
15 changes: 13 additions & 2 deletions src/picklescan/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __str__(self) -> str:
# File-extension sets used to classify files by name. Only the pickle set's
# use is visible in this hunk (scan_zip_bytes checks zip member extensions
# against it); the other two are presumably consulted by sibling scan code.
_pytorch_file_extensions = {".bin", ".pt", ".pth", ".ckpt"}
_pickle_file_extensions = {".pkl", ".pickle", ".joblib", ".dat", ".data"}
_zip_file_extensions = {".zip", ".npz", ".7z"}
# Two-byte prefixes: 0x80 (the pickle PROTO opcode) followed by a version byte
# 0 through 5. scan_zip_bytes treats a zip member whose leading bytes start
# with one of these as a pickle even when its extension is not a pickle one.
# NOTE(review): a stream that does not open with PROTO will not match — confirm
# that gap is acceptable.
_pickle_magic_bytes = {b"\x80\x00", b"\x80\x01", b"\x80\x02", b"\x80\x03", b"\x80\x04", b"\x80\x05"}


def _is_7z_file(f: IO[bytes]) -> bool:
Expand Down Expand Up @@ -349,20 +350,30 @@ def scan_7z_bytes(data: IO[bytes], file_id) -> ScanResult:

return result

def get_magic_bytes_from_zipfile(zip: zipfile.ZipFile, num_bytes=8) -> dict:
    """Read the leading bytes of every member of an open zip archive.

    Args:
        zip: An open ``zipfile.ZipFile`` to inspect. (The parameter name
            shadows the builtin but is kept for caller compatibility.)
        num_bytes: Maximum number of bytes to read from the start of each
            member; shorter members yield fewer bytes.

    Returns:
        A dict mapping each member's file name to the bytes read from its
        start, in archive order.
    """
    magic_bytes = {}
    for file_info in zip.infolist():
        # Open through the ZipInfo we already hold instead of re-resolving
        # the member by name; each member is closed right after the read.
        with zip.open(file_info) as member:
            magic_bytes[file_info.filename] = member.read(num_bytes)

    return magic_bytes


def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
result = ScanResult([])

with zipfile.ZipFile(data, "r") as zip:
magic_bytes = get_magic_bytes_from_zipfile(zip)
file_names = zip.namelist()
_log.debug("Files in zip archive %s: %s", file_id, file_names)
for file_name in file_names:
magic_number = magic_bytes.get(file_name, b"")
file_ext = os.path.splitext(file_name)[1]
if file_ext in _pickle_file_extensions:
if file_ext in _pickle_file_extensions or any(magic_number.startswith(mn) for mn in _pickle_magic_bytes):
_log.debug("Scanning file %s in zip archive %s", file_name, file_id)
with zip.open(file_name, "r") as file:
result.merge(scan_pickle_bytes(file, f"{file_id}:{file_name}"))
elif file_ext in _numpy_file_extensions:
elif file_ext in _numpy_file_extensions or magic_number.startswith(b"\x93NUMPY"):
_log.debug("Scanning file %s in zip archive %s", file_name, file_id)
with zip.open(file_name, "r") as file:
result.merge(scan_numpy(file, f"{file_id}:{file_name}"))
Expand Down
Binary file added tests/data/malicious1_wrong_ext.zip
Binary file not shown.
19 changes: 16 additions & 3 deletions tests/test_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,12 @@ def initialize_pickle_files():
pickle.dumps(Malicious1(), protocol=4),
)

initialize_zip_file(
f"{_root_path}/data/malicious1_wrong_ext.zip",
"data.txt", # Pickle file with a non-standard extension
pickle.dumps(Malicious1(), protocol=4),
)

# Fake PyTorch file (PNG file format) simulating https://huggingface.co/RectalWorm/loras_new/blob/main/Owl_Mage_no_background.pt
initialize_data_file(f"{_root_path}/data/bad_pytorch.pt", b"\211PNG\r\n\032\n")

Expand Down Expand Up @@ -593,6 +599,12 @@ def test_scan_file_path():
compare_scan_results(
scan_file_path(f"{_root_path}/data/malicious1.zip"), malicious1
)
compare_scan_results(
scan_file_path(f"{_root_path}/data/malicious1.7z"), malicious1
)
compare_scan_results(
scan_file_path(f"{_root_path}/data/malicious1_wrong_ext.zip"), malicious1
)

malicious2 = ScanResult([Global("posix", "system", SafetyLevel.Dangerous)], 1, 1, 1)
compare_scan_results(
Expand Down Expand Up @@ -772,10 +784,11 @@ def test_scan_directory_path():
Global("builtins", "exec", SafetyLevel.Dangerous),
Global("builtins", "eval", SafetyLevel.Dangerous),
Global("pip", "main", SafetyLevel.Dangerous),
Global("builtins", "eval", SafetyLevel.Dangerous),
],
scanned_files=33,
issues_count=33,
infected_files=28,
scanned_files=34,
issues_count=34,
infected_files=29,
scan_err=True,
)
compare_scan_results(scan_directory_path(f"{_root_path}/data/"), sr)
Expand Down

0 comments on commit baf03fa

Please sign in to comment.