From 3ea57a54fede00b7aa42314228ce20993954f1fd Mon Sep 17 00:00:00 2001
From: Tommy Beadle
Date: Fri, 24 May 2024 13:43:46 -0400
Subject: [PATCH] Fix main database usage in dist.py. (#2138)

Co-authored-by: Tommy Beadle
---
 utils/dist.py | 374 ++++++++++++++++++++++++++------------------------
 1 file changed, 197 insertions(+), 177 deletions(-)

diff --git a/utils/dist.py b/utils/dist.py
index d9a8cfa9601..54757d5b2a3 100644
--- a/utils/dist.py
+++ b/utils/dist.py
@@ -31,7 +31,7 @@
 try:
     import pyzipper
 except ImportError:
-    sys.exti("Missed pyzipper dependency: poetry install")
+    sys.exit("Missed pyzipper dependency: poetry install")
 
 CUCKOO_ROOT = os.path.join(os.path.abspath(os.path.dirname(__file__)), "..")
 sys.path.append(CUCKOO_ROOT)
@@ -53,7 +53,7 @@
     Database,
 )
 from lib.cuckoo.core.database import Task as MD_Task
-from lib.cuckoo.core.database import init_database
+from lib.cuckoo.core.database import _Database, init_database
 
 dist_conf = Config("distributed")
 main_server_name = dist_conf.distributed.get("main_server_name", "master")
@@ -84,7 +84,7 @@
 STATUSES = {}
 ID2NAME = {}
 SERVER_TAGS = {}
-main_db = Database()
+main_db: _Database = Database()
 
 dead_count = 5
 if dist_conf.distributed.dead_count:
@@ -469,7 +469,8 @@ def notification_loop(self):
                 tasks = db.query(Task).filter_by(finished=True, retrieved=True, notificated=False).order_by(Task.id.desc()).all()
                 if tasks is not None:
                     for task in tasks:
-                        main_db.set_status(task.main_task_id, TASK_REPORTED)
+                        with main_db.session.begin():
+                            main_db.set_status(task.main_task_id, TASK_REPORTED)
                         log.debug("reporting main_task_id: {}".format(task.main_task_id))
                         for url in urls:
                             try:
@@ -494,7 +495,8 @@ def failed_cleaner(self):
                         t = db.query(Task).filter_by(task_id=task["id"], node_id=node.id).order_by(Task.id.desc()).first()
                         if t is not None:
                             log.info("Cleaning failed for id:{}, node:{}: main_task_id: {}".format(t.id, t.node_id, t.main_task_id))
-                            main_db.set_status(t.main_task_id, TASK_FAILED_REPORTING)
+                            with main_db.session.begin():
+                                main_db.set_status(t.main_task_id, TASK_FAILED_REPORTING)
                             t.finished = True
                             t.retrieved = True
                             t.notificated = True
@@ -590,8 +592,11 @@ def delete_target_file(self, task_id: int, sample_sha256: str, target: str):
 
         if cfg.cuckoo.delete_bin_copy:
            copy_path = os.path.join(CUCKOO_ROOT, "storage", "binaries", sample_sha256)
-            if path_exists(copy_path) and not main_db.sample_still_used(sample_sha256, task_id):
-                path_delete(copy_path)
+            if path_exists(copy_path):
+                with main_db.session.begin():
+                    sample_still_used = main_db.sample_still_used(sample_sha256, task_id)
+                if not sample_still_used:
+                    path_delete(copy_path)
 
     # This should be executed as external thread as it generates bottle neck
     def fetch_latest_reports_nfs(self):
@@ -631,10 +636,11 @@ def fetch_latest_reports_nfs(self):
                         t.id, t.task_id, t.main_task_id, ID2NAME[t.node_id] if t.node_id in ID2NAME else t.node_id
                     )
                 )
-                # set completed_on time
-                main_db.set_status(t.main_task_id, TASK_DISTRIBUTED_COMPLETED)
-                # set reported time
-                main_db.set_status(t.main_task_id, TASK_REPORTED)
+                with main_db.session.begin():
+                    # set completed_on time
+                    main_db.set_status(t.main_task_id, TASK_DISTRIBUTED_COMPLETED)
+                    # set reported time
+                    main_db.set_status(t.main_task_id, TASK_REPORTED)
 
                 # Fetch each requested report.
                 report_path = os.path.join(CUCKOO_ROOT, "storage", "analyses", f"{t.main_task_id}")
@@ -653,10 +659,12 @@ def fetch_latest_reports_nfs(self):
 
                 # this doesn't exist for some reason
                 if path_exists(t.path):
-                    sample_sha256 = main_db.find_sample(task_id=t.main_task_id)
-                    if sample_sha256:
-                        sample_sha256 = sample_sha256[0].sample.sha256
-                    else:
+                    sample_sha256 = None
+                    with main_db.session.begin():
+                        samples = main_db.find_sample(task_id=t.main_task_id)
+                        if samples:
+                            sample_sha256 = samples[0].sample.sha256
+                    if sample_sha256 is None:
                         # keep fallback for now
                         sample = open(t.path, "rb").read()
                         sample_sha256 = hashlib.sha256(sample).hexdigest()
@@ -726,10 +734,11 @@ def fetch_latest_reports(self):
                         t.id, t.task_id, t.main_task_id, ID2NAME[t.node_id] if t.node_id in ID2NAME else t.node_id
                     )
                 )
-                # set completed_on time
-                main_db.set_status(t.main_task_id, TASK_DISTRIBUTED_COMPLETED)
-                # set reported time
-                main_db.set_status(t.main_task_id, TASK_REPORTED)
+                with main_db.session.begin():
+                    # set completed_on time
+                    main_db.set_status(t.main_task_id, TASK_DISTRIBUTED_COMPLETED)
+                    # set reported time
+                    main_db.set_status(t.main_task_id, TASK_REPORTED)
 
                 # Fetch each requested report.
                 node = db.query(Node).with_entities(Node.id, Node.name, Node.url, Node.apikey).filter_by(id=node_id).first()
@@ -768,10 +777,12 @@ def fetch_latest_reports(self):
                             log.error("Permission denied: {}".format(report_path))
 
                 if path_exists(t.path):
-                    sample_sha256 = main_db.find_sample(task_id=t.main_task_id)
-                    if sample_sha256:
-                        sample_sha256 = sample_sha256[0].sample.sha256
-                    else:
+                    sample_sha256 = None
+                    with main_db.session.begin():
+                        samples = main_db.find_sample(task_id=t.main_task_id)
+                        if samples:
+                            sample_sha256 = samples[0].sample.sha256
+                    if sample_sha256 is None:
                         # keep fallback for now
                         sample = open(t.path, "rb").read()
                         sample_sha256 = hashlib.sha256(sample).hexdigest()
@@ -863,173 +874,181 @@ def submit_tasks(self, node_id, pend_tasks_num, options_like=False, force_push_p
             for task in bad_tasks:
                 db.delete(task)
                 db.commit()
-                main_db.set_status(task.main_task_id, TASK_PENDING)
+                with main_db.session.begin():
+                    main_db.set_status(task.main_task_id, TASK_PENDING)
 
         if node.name != main_server_name:
             # don"t do nothing if nothing in pending
            # Get tasks from main_db submitted through web interface
            # Exclude category
-            main_db_tasks = main_db.list_tasks(
-                status=TASK_PENDING, options_like=options_like, limit=pend_tasks_num, order_by=MD_Task.priority.desc()
-            )
-            if not main_db_tasks:
-                return True
-            if main_db_tasks:
-                for t in main_db_tasks:
-                    options = get_options(t.options)
-                    # Check if file exist, if no wipe from db and continue, rare cases
-                    if t.category in ("file", "pcap", "static"):
-                        if not path_exists(t.target):
-                            log.info(f"Task id: {t.id} - File doesn't exist: {t.target}")
-                            main_db.set_status(t.id, TASK_BANNED)
-                            continue
-
-                        if not web_conf.general.allow_ignore_size and "ignore_size_check" not in options:
-                            # We can't upload size bigger than X to our workers. In case we extract archive that contains bigger file.
-                            file_size = path_get_size(t.target)
-                            if file_size > web_conf.general.max_sample_size:
-                                log.warning(f"File size: {file_size} is bigger than allowed: {web_conf.general.max_sample_size}")
-                                main_db.set_status(t.id, TASK_BANNED)
-                                continue
-
-                    force_push = False
-                    try:
-                        # check if node exist and its correct
-                        if options.get("node"):
-                            requested_node = options.get("node")
-                            if requested_node not in STATUSES:
-                                # if the requested node is not available
-                                force_push = True
-                            elif requested_node != node.name:
-                                # otherwise keep looping
-                                continue
-                        if "timeout=" in t.options:
-                            t.timeout = options.get("timeout", 0)
-                    except Exception as e:
-                        log.error(e, exc_info=True)
-                        # wtf are you doing in pendings?
-                        tasks = db.query(Task).filter_by(main_task_id=t.id).all()
-                        if tasks:
-                            for task in tasks:
-                                # log.info("Deleting incorrectly uploaded file from dist db, main_task_id: {}".format(t.id))
-                                if node.name == main_server_name:
-                                    main_db.set_status(t.id, TASK_RUNNING)
-                                else:
-                                    main_db.set_status(t.id, TASK_DISTRIBUTED)
-                                # db.delete(task)
-                        db.commit()
-                        continue
-
-                    # Convert array of tags into comma separated list
-                    tags = ",".join([tag.name for tag in t.tags])
-                    # Append a comma, to make LIKE searches more precise
-                    if tags:
-                        tags += ","
-
-                    # sanity check
-                    if "x86" in tags and "x64" in tags:
-                        tags = tags.replace("x86,", "")
-
-                    if "msoffice-crypt-tmp" in t.target and "password=" in t.options:
-                        # t.options = t.options.replace(f"password={options['password']}", "")
-                        options["password"]
-                    # if options.get("node"):
-                    #     t.options = t.options.replace(f"node={options['node']}", "")
-                    if options.get("node"):
-                        del options["node"]
-                    t.options = ",".join([f"{k}={v}" for k, v in options.items()])
-                    if t.options:
-                        t.options += ","
-
-                    t.options += "main_task_id={}".format(t.id)
-                    args = dict(
-                        package=t.package,
-                        category=t.category,
-                        timeout=t.timeout,
-                        priority=t.priority,
-                        options=t.options,
-                        machine=t.machine,
-                        platform=t.platform,
-                        tags=tags,
-                        custom=t.custom,
-                        memory=t.memory,
-                        clock=t.clock,
-                        enforce_timeout=t.enforce_timeout,
-                        main_task_id=t.id,
-                        route=t.route,
-                        tlp=t.tlp,
-                    )
-                    task = Task(path=t.target, **args)
-
-                    db.add(task)
-                    try:
-                        db.commit()
-                    except Exception as e:
-                        log.exception(e)
-                        log.info("TASK_FAILED_REPORTING")
-                        db.rollback()
-                        log.info(e)
-                        continue
-
-                    if force_push or force_push_push:
-                        # Submit appropriate tasks to node
-                        submitted = node_submit_task(task.id, node.id, t.id)
-                        if submitted:
-                            if node.name == main_server_name:
-                                main_db.set_status(t.id, TASK_RUNNING)
-                            else:
-                                main_db.set_status(t.id, TASK_DISTRIBUTED)
-                        limit += 1
-                        if limit in (pend_tasks_num, len(main_db_tasks)):
-                            db.commit()
-                            log.info("Pushed all tasks")
-                            return True
-
-            # Only get tasks that have not been pushed yet.
-            q = db.query(Task).filter(or_(Task.node_id.is_(None), Task.task_id.is_(None)), Task.finished.is_(False))
-            if q is None:
-                db.commit()
-                return True
-            # Order by task priority and task id.
-            q = q.order_by(-Task.priority, Task.main_task_id)
-            # if we have node set in options push
-            if dist_conf.distributed.enable_tags:
-                # Create filter query from tasks in ta
-                tags = [getattr(Task, "tags") == ""]
-                for tg in SERVER_TAGS[node.name]:
-                    if len(tg.split(",")) == 1:
-                        tags.append(getattr(Task, "tags") == (tg + ","))
-                    else:
-                        tg = tg.split(",")
-                        # ie. LIKE "%,%,%,"
-                        t_combined = [getattr(Task, "tags").like("%s" % ("%," * len(tg)))]
-                        for tag in tg:
-                            t_combined.append(getattr(Task, "tags").like("%%%s%%" % (tag + ",")))
-                        tags.append(and_(*t_combined))
-                # Filter by available tags
-                q = q.filter(or_(*tags))
-            to_upload = q.limit(pend_tasks_num).all()
-            if not to_upload:
-                db.commit()
-                log.info("nothing to upload? How? o_O")
-                return False
-            # Submit appropriate tasks to node
-            log.debug("going to upload {} tasks to node {}".format(pend_tasks_num, node.name))
-            for task in to_upload:
-                submitted = node_submit_task(task.id, node.id, task.main_task_id)
-                if submitted:
-                    if node.name == main_server_name:
-                        main_db.set_status(task.main_task_id, TASK_RUNNING)
-                    else:
-                        main_db.set_status(task.main_task_id, TASK_DISTRIBUTED)
-                else:
-                    log.info("something is wrong with submission of task: {}".format(task.id))
-                    db.delete(task)
-                    db.commit()
-                limit += 1
-                if limit == pend_tasks_num:
-                    db.commit()
-                    return True
+            with main_db.session.begin():
+                main_db_tasks = main_db.list_tasks(
+                    status=TASK_PENDING,
+                    options_like=options_like,
+                    limit=pend_tasks_num,
+                    order_by=MD_Task.priority.desc(),
+                    for_update=True,
+                )
+                if not main_db_tasks:
+                    return True
+                if main_db_tasks:
+                    for t in main_db_tasks:
+                        options = get_options(t.options)
+                        # Check if file exist, if no wipe from db and continue, rare cases
+                        if t.category in ("file", "pcap", "static"):
+                            if not path_exists(t.target):
+                                log.info(f"Task id: {t.id} - File doesn't exist: {t.target}")
+                                main_db.set_status(t.id, TASK_BANNED)
+                                continue
+
+                            if not web_conf.general.allow_ignore_size and "ignore_size_check" not in options:
+                                # We can't upload size bigger than X to our workers. In case we extract archive that contains bigger file.
+                                file_size = path_get_size(t.target)
+                                if file_size > web_conf.general.max_sample_size:
+                                    log.warning(
+                                        f"File size: {file_size} is bigger than allowed: {web_conf.general.max_sample_size}"
+                                    )
+                                    main_db.set_status(t.id, TASK_BANNED)
+                                    continue
+
+                        force_push = False
+                        try:
+                            # check if node exist and its correct
+                            if options.get("node"):
+                                requested_node = options.get("node")
+                                if requested_node not in STATUSES:
+                                    # if the requested node is not available
+                                    force_push = True
+                                elif requested_node != node.name:
+                                    # otherwise keep looping
+                                    continue
+                            if "timeout=" in t.options:
+                                t.timeout = options.get("timeout", 0)
+                        except Exception as e:
+                            log.error(e, exc_info=True)
+                            # wtf are you doing in pendings?
+                            tasks = db.query(Task).filter_by(main_task_id=t.id).all()
+                            if tasks:
+                                for task in tasks:
+                                    # log.info("Deleting incorrectly uploaded file from dist db, main_task_id: {}".format(t.id))
+                                    if node.name == main_server_name:
+                                        main_db.set_status(t.id, TASK_RUNNING)
+                                    else:
+                                        main_db.set_status(t.id, TASK_DISTRIBUTED)
+                                    # db.delete(task)
+                            db.commit()
+                            continue
+
+                        # Convert array of tags into comma separated list
+                        tags = ",".join([tag.name for tag in t.tags])
+                        # Append a comma, to make LIKE searches more precise
+                        if tags:
+                            tags += ","
+
+                        # sanity check
+                        if "x86" in tags and "x64" in tags:
+                            tags = tags.replace("x86,", "")
+
+                        if "msoffice-crypt-tmp" in t.target and "password=" in t.options:
+                            # t.options = t.options.replace(f"password={options['password']}", "")
+                            options["password"]
+                        # if options.get("node"):
+                        #     t.options = t.options.replace(f"node={options['node']}", "")
+                        if options.get("node"):
+                            del options["node"]
+                        t.options = ",".join([f"{k}={v}" for k, v in options.items()])
+                        if t.options:
+                            t.options += ","
+
+                        t.options += "main_task_id={}".format(t.id)
+                        args = dict(
+                            package=t.package,
+                            category=t.category,
+                            timeout=t.timeout,
+                            priority=t.priority,
+                            options=t.options,
+                            machine=t.machine,
+                            platform=t.platform,
+                            tags=tags,
+                            custom=t.custom,
+                            memory=t.memory,
+                            clock=t.clock,
+                            enforce_timeout=t.enforce_timeout,
+                            main_task_id=t.id,
+                            route=t.route,
+                            tlp=t.tlp,
+                        )
+                        task = Task(path=t.target, **args)
+
+                        db.add(task)
+                        try:
+                            db.commit()
+                        except Exception as e:
+                            log.exception(e)
+                            log.info("TASK_FAILED_REPORTING")
+                            db.rollback()
+                            log.info(e)
+                            continue
+
+                        if force_push or force_push_push:
+                            # Submit appropriate tasks to node
+                            submitted = node_submit_task(task.id, node.id, t.id)
+                            if submitted:
+                                if node.name == main_server_name:
+                                    main_db.set_status(t.id, TASK_RUNNING)
+                                else:
+                                    main_db.set_status(t.id, TASK_DISTRIBUTED)
+                            limit += 1
+                            if limit in (pend_tasks_num, len(main_db_tasks)):
+                                db.commit()
+                                log.info("Pushed all tasks")
+                                return True
+
+                # Only get tasks that have not been pushed yet.
+                q = db.query(Task).filter(or_(Task.node_id.is_(None), Task.task_id.is_(None)), Task.finished.is_(False))
+                if q is None:
+                    db.commit()
+                    return True
+                # Order by task priority and task id.
+                q = q.order_by(-Task.priority, Task.main_task_id)
+                # if we have node set in options push
+                if dist_conf.distributed.enable_tags:
+                    # Create filter query from tasks in ta
+                    tags = [getattr(Task, "tags") == ""]
+                    for tg in SERVER_TAGS[node.name]:
+                        if len(tg.split(",")) == 1:
+                            tags.append(getattr(Task, "tags") == (tg + ","))
+                        else:
+                            tg = tg.split(",")
+                            # ie. LIKE "%,%,%,"
+                            t_combined = [getattr(Task, "tags").like("%s" % ("%," * len(tg)))]
+                            for tag in tg:
+                                t_combined.append(getattr(Task, "tags").like("%%%s%%" % (tag + ",")))
+                            tags.append(and_(*t_combined))
+                    # Filter by available tags
+                    q = q.filter(or_(*tags))
+                to_upload = q.limit(pend_tasks_num).all()
+                if not to_upload:
+                    db.commit()
+                    log.info("nothing to upload? How? o_O")
+                    return False
+                # Submit appropriate tasks to node
+                log.debug("going to upload {} tasks to node {}".format(pend_tasks_num, node.name))
+                for task in to_upload:
+                    submitted = node_submit_task(task.id, node.id, task.main_task_id)
+                    if submitted:
+                        if node.name == main_server_name:
+                            main_db.set_status(task.main_task_id, TASK_RUNNING)
+                        else:
+                            main_db.set_status(task.main_task_id, TASK_DISTRIBUTED)
+                    else:
+                        log.info("something is wrong with submission of task: {}".format(task.id))
+                        db.delete(task)
+                        db.commit()
+                    limit += 1
+                    if limit == pend_tasks_num:
+                        db.commit()
+                        return True
 
         db.commit()
         return True
@@ -1047,7 +1066,7 @@ def load_vm_tags(self, db, node_id, node_name):
         SERVER_TAGS[node_name] = list(ta)
 
     def run(self):
-        global main_db, retrieve, STATUSES
+        global retrieve, STATUSES
         MINIMUMQUEUE = {}
 
         # handle another user case,
@@ -1553,10 +1572,11 @@ def init_logging(debug=False):
         sys.exit()
 
     if args.force_reported:
-        # set completed_on time
-        main_db.set_status(args.force_reported, TASK_DISTRIBUTED_COMPLETED)
-        # set reported time
-        main_db.set_status(args.force_reported, TASK_REPORTED)
+        with main_db.session.begin():
+            # set completed_on time
+            main_db.set_status(args.force_reported, TASK_DISTRIBUTED_COMPLETED)
+            # set reported time
+            main_db.set_status(args.force_reported, TASK_REPORTED)
         sys.exit()
 
     delete_enabled = args.enable_clean
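
The change the patch applies is mechanical but easy to get wrong when reproducing it elsewhere: main_db methods such as set_status, find_sample and list_tasks no longer own their transactions, so every caller opens one explicitly with `with main_db.session.begin():`, and `list_tasks(..., for_update=True)` requests row locks so concurrent submitters cannot pick up the same pending tasks. Below is a minimal sketch of that pattern; the Task model and the set_status/list_tasks helpers are simplified stand-ins for illustration, not CAPE's actual _Database API.

# A minimal sketch of the caller-owned-transaction pattern, not CAPE's
# real API: a module-level database object whose methods assume the
# caller has already opened a transaction.
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker

Base = declarative_base()


class Task(Base):
    __tablename__ = "tasks"
    id = Column(Integer, primary_key=True)
    status = Column(String, default="pending")


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)


class _Database:
    def __init__(self):
        # scoped_session hands each thread its own Session, mirroring how
        # the dist.py threads share the single module-level main_db.
        self.session = scoped_session(sessionmaker(bind=engine))

    def set_status(self, task_id, status):
        # No commit here: the caller's `with session.begin():` block
        # commits or rolls back the whole unit of work.
        task = self.session.get(Task, task_id)
        if task is not None:
            task.status = status

    def list_tasks(self, status, limit, for_update=False):
        stmt = select(Task).where(Task.status == status).limit(limit)
        if for_update:
            # SELECT ... FOR UPDATE keeps the rows locked until the
            # surrounding transaction ends.
            stmt = stmt.with_for_update()
        return self.session.scalars(stmt).all()


main_db = _Database()

with main_db.session.begin():
    main_db.session.add(Task())

with main_db.session.begin():  # one transaction around read + writes
    for t in main_db.list_tasks("pending", limit=10, for_update=True):
        main_db.set_status(t.id, "distributed")

On exit from each `with` block, SQLAlchemy commits the transaction (or rolls it back on an exception), so the status updates and the SELECT ... FOR UPDATE row locks live and die together. Note that SQLite silently ignores FOR UPDATE; the locking behavior only takes effect on backends such as PostgreSQL, which is what distributed CAPE deployments typically use.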