Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve performance with multiprocessing #60

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 99 additions & 36 deletions log4j-finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import itertools
import collections
import fnmatch
import threading
import multiprocessing

from pathlib import Path

Expand Down Expand Up @@ -131,12 +133,12 @@ def iter_scandir(path, stats=None, exclude=None):

def scantree(path, stats=None, exclude=None):
"""Recursively yield DirEntry objects for given directory."""
exclude = exclude or []
exclude = exclude or []
try:
with os.scandir(path) as it:
for entry in it:
if any(fnmatch.fnmatch(entry.path, exclusion) for exclusion in exclude):
continue
continue
if entry.is_dir(follow_symlinks=False):
if stats is not None:
stats["directories"] += 1
Expand Down Expand Up @@ -228,8 +230,10 @@ def check_vulnerable(fobj, path_chain, stats, has_jndilookup=True):
md5sum = md5_digest(fobj)
first_path = bold(path_chain.pop(0))
path_chain = " -> ".join(str(p) for p in [first_path] + path_chain)
comment = collections.ChainMap(MD5_BAD, MD5_GOOD).get(md5sum, "Unknown MD5")
color_map = {"vulnerable": red, "good": green, "patched": cyan, "unknown": yellow}
comment = collections.ChainMap(
MD5_BAD, MD5_GOOD).get(md5sum, "Unknown MD5")
color_map = {"vulnerable": red, "good": green,
"patched": cyan, "unknown": yellow}
if md5sum in MD5_BAD:
status = "vulnerable" if has_jndilookup else "patched"
elif md5sum in MD5_GOOD:
Expand All @@ -243,7 +247,7 @@ def check_vulnerable(fobj, path_chain, stats, has_jndilookup=True):
status = bold(color(status.upper()))
md5sum = color(md5sum)
comment = bold(color(comment))
print(f"[{now}] {hostname} {status}: {path_chain} [{md5sum}: {comment}]")
check_vulnerable.print_queue.put(f"[{now}] {hostname} {status}: {path_chain} [{md5sum}: {comment}]")


def print_summary(stats):
Expand All @@ -258,6 +262,73 @@ def print_summary(stats):
print(" Found {} patched files".format(stats["patched"]))
if stats["unknown"]:
print(" Found {} unknown files".format(stats["unknown"]))
sys.stdout.flush()


def print_thread(print_queue: multiprocessing.Queue):
    """
    Drain ``print_queue`` and write every item to stdout.

    Blocks on ``print_queue.get()`` and prints each received item in the
    order it arrives. The loop ends, without printing, as soon as an item
    equal to the string ``"CLOSE"`` is received.
    """
    # iter(callable, sentinel) keeps calling get() until it returns "CLOSE".
    for message in iter(print_queue.get, "CLOSE"):
        print(message)


def scan_init(print_queue: multiprocessing.Queue, log_level):
    """
    Initialize logging and printing for a child process.

    Attaches a stream handler (format: timestamp, level, message) to the
    ``multiprocessing`` logger and sets it to ``log_level``. The logger and
    ``print_queue`` are then stored as function attributes (``scan.log`` and
    ``check_vulnerable.print_queue``) so that those functions can reach them
    without extra arguments.
    """
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    )

    worker_logger = multiprocessing.get_logger()
    worker_logger.addHandler(stream_handler)
    worker_logger.setLevel(log_level)

    # Function attributes act as per-process globals for the scan functions.
    check_vulnerable.print_queue = print_queue
    scan.log = worker_logger


def scan(p: Path):
    """
    Scans a single path for known bad or known good filenames and MD5 hashes.

    Two independent checks are made on ``p``:

    * If its lower-cased name is in ``FILENAMES`` the file itself is opened
      and handed to ``check_vulnerable``.
    * If its lower-cased suffix is in ``JAR_EXTENSIONS`` the archive is
      walked with ``iter_jarfile`` and every yielded entry is handed to
      ``check_vulnerable``.

    For ``JndiManager.class`` hits, the presence of a sibling
    ``lookup/JndiLookup.class`` is probed and passed on as ``has_lookup``.

    Expects ``scan.log`` to have been set beforehand (``scan_init`` does this).

    :param p: path of the file to inspect.
    :return: a ``collections.Counter`` holding the statistics gathered for
        this path (at least the ``"scanned"`` count).
    """
    stats = collections.Counter()
    manager_suffix = "JndiManager.class".lower()

    if p.name.lower() in FILENAMES:
        stats["scanned"] += 1
        scan.log.info(f"Found file: {p}")
        with p.open("rb") as fobj:
            # If we find JndiManager, we also check if JndiLookup.class exists
            has_lookup = True
            if p.name.lower().endswith(manager_suffix):
                lookup_path = p.parent.parent / "lookup/JndiLookup.class"
                has_lookup = lookup_path.exists()
            check_vulnerable(fobj, [p], stats, has_lookup)

    if p.suffix.lower() in JAR_EXTENSIONS:
        try:
            scan.log.info(f"Found jar file: {p}")
            stats["scanned"] += 1
            # Open the archive in a context manager so the handle is always
            # released, even when iteration stops early because of an error.
            with p.open("rb") as jar_fobj:
                for (zinfo, zfile, zpath, parents) in iter_jarfile(
                    jar_fobj, parents=[p]
                ):
                    scan.log.info(f"Found zfile: {zinfo} ({parents}")
                    with zfile.open(zinfo.filename) as zf:
                        # If we find JndiManager.class, we also check if
                        # JndiLookup.class exists
                        has_lookup = True
                        if zpath.name.lower().endswith(manager_suffix):
                            lookup_path = zpath.parent.parent / "lookup/JndiLookup.class"
                            try:
                                # Only existence matters: open the member and
                                # close it again right away instead of
                                # keeping the handle around.
                                with zfile.open(lookup_path.as_posix()):
                                    pass
                            except KeyError:
                                has_lookup = False
                        check_vulnerable(
                            zf, parents + [zpath], stats, has_lookup)
        except IOError as e:
            scan.log.debug(f"{p}: {e}")
    return stats


def main():
Expand Down Expand Up @@ -326,37 +397,28 @@ def main():
now = datetime.datetime.utcnow().replace(microsecond=0)
if not args.quiet:
print(f"[{now}] {hostname} Scanning: {directory}")
for p in iter_scandir(directory, stats=stats, exclude=args.exclude):
if p.name.lower() in FILENAMES:
stats["scanned"] += 1
log.info(f"Found file: {p}")
with p.open("rb") as fobj:
# If we find JndiManager, we also check if JndiLookup.class exists
has_lookup = True
if p.name.lower().endswith("JndiManager.class".lower()):
lookup_path = p.parent.parent / "lookup/JndiLookup.class"
has_lookup = lookup_path.exists()
check_vulnerable(fobj, [p], stats, has_lookup)
if p.suffix.lower() in JAR_EXTENSIONS:
try:
log.info(f"Found jar file: {p}")
stats["scanned"] += 1
for (zinfo, zfile, zpath, parents) in iter_jarfile(
p.open("rb"), parents=[p]
):
log.info(f"Found zfile: {zinfo} ({parents}")
with zfile.open(zinfo.filename) as zf:
# If we find JndiManager.class, we also check if JndiLookup.class exists
has_lookup = True
if zpath.name.lower().endswith("JndiManager.class".lower()):
lookup_path = zpath.parent.parent / "lookup/JndiLookup.class"
try:
has_lookup = zfile.open(lookup_path.as_posix())
except KeyError:
has_lookup = False
check_vulnerable(zf, parents + [zpath], stats, has_lookup)
except IOError as e:
log.debug(f"{p}: {e}")

try:
q = multiprocessing.Queue()

# start printing thread
pt = threading.Thread(target=print_thread, args=(q,))
pt.start()

# Workers = 2*num_cores
pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()*2, initializer=scan_init, initargs=(q, log.getEffectiveLevel()))
res = pool.map(scan, iter_scandir(directory, stats=stats, exclude=args.exclude))
for r in res:
stats += r
except (KeyboardInterrupt, SystemExit, Exception) as ex:
pool.terminate()
raise ex
else:
pool.close()
finally:
q.put("CLOSE")
pool.join()
pt.join()

elapsed = time.monotonic() - start_time
now = datetime.datetime.utcnow().replace(microsecond=0)
Expand All @@ -367,6 +429,7 @@ def main():


if __name__ == "__main__":
multiprocessing.freeze_support()
try:
sys.exit(main())
except KeyboardInterrupt:
Expand Down