diff --git a/.gitignore b/.gitignore index dfe73515..2ca99756 100644 --- a/.gitignore +++ b/.gitignore @@ -58,7 +58,7 @@ cover/ *.pot # Django stuff: -*.log +*.log* local_settings.py db.sqlite3 db.sqlite3-journal diff --git a/CveXplore/VERSION b/CveXplore/VERSION index d976e9a4..ffa098d2 100644 --- a/CveXplore/VERSION +++ b/CveXplore/VERSION @@ -1 +1 @@ -0.1.8.dev1 \ No newline at end of file +0.1.8.dev2 \ No newline at end of file diff --git a/CveXplore/main.py b/CveXplore/main.py index 47d74aee..ecc44b7a 100644 --- a/CveXplore/main.py +++ b/CveXplore/main.py @@ -13,7 +13,7 @@ from CveXplore.common.db_mapping import database_mapping from CveXplore.database.connection.mongo_db import MongoDBConnection from CveXplore.errors import DatabaseIllegalCollection -from CveXplore.lib.main_updater import MainUpdater +from CveXplore.update.main_updater import MainUpdater try: from version import VERSION diff --git a/CveXplore/update/Config.py b/CveXplore/update/Config.py new file mode 100644 index 00000000..1761700e --- /dev/null +++ b/CveXplore/update/Config.py @@ -0,0 +1,123 @@ +import datetime +import os +import re +import urllib.parse + +import pymongo +import redis + +runPath = os.path.dirname(os.path.realpath(__file__)) + + +class Configuration(object): + CVE_START_YEAR = os.getenv("CVE_START_YEAR", 2002) + + SOURCES = os.getenv( + "SOURCES", + { + "cve": "https://nvd.nist.gov/feeds/json/cve/1.1/", + "cpe": "https://nvd.nist.gov/feeds/json/cpematch/1.0/nvdcpematch-1.0.json.zip", + "cwe": "https://cwe.mitre.org/data/xml/cwec_v4.4.xml.zip", + "capec": "https://capec.mitre.org/data/xml/capec_v3.4.xml", + "via4": "https://www.cve-search.org/feeds/via4.json", + }, + ) + + HTTP_PROXY = os.getenv("HTTP_PROXY", "") + + LOGGING_MAX_FILE_SIZE = os.getenv("LOGGING_MAX_FILE_SIZE", "100MB") + LOGGING_BACKLOG = os.getenv("LOGGING_BACKLOG", 5) + LOGGING_FILE_NAME = os.getenv("LOGGING_FILE_NAME", "./log/update_populate.log") + + MONGO_HOST = os.getenv("MONGO_HOST", "localhost") + MONGO_PORT = os.getenv("MONGO_PORT", 27017) + MONGO_DB = os.getenv("MONGO_DB", "cvexdb") + MONGO_USER = os.getenv("MONGO_USER", "") + MONGO_PASS = os.getenv("MONGO_PASS", "") + + REDIS_HOST = os.getenv("REDIS_HOST", "localhost") + REDIS_PORT = os.getenv("REDIS_PORT", 6379) + REDIS_PASS = os.getenv("REDIS_PASS", None) + REDIS_Q = os.getenv("REDIS_Q", 9) + + @classmethod + def getCVEStartYear(cls): + next_year = datetime.datetime.now().year + 1 + start_year = cls.CVE_START_YEAR + if start_year < cls.CVE_START_YEAR or start_year > next_year: + print( + "The year %i is not a valid year.\ndefault year %i will be used." 
+ % (start_year, cls.default["CVEStartYear"]) + ) + start_year = cls.default["CVEStartYear"] + return start_year + + @classmethod + def getProxy(cls): + return cls.HTTP_PROXY + + @classmethod + def getFeedURL(cls, source): + return cls.SOURCES[source] + + @classmethod + def toPath(cls, path): + return path if os.path.isabs(path) else os.path.join(runPath, "..", path) + + @classmethod + def getUpdateLogFile(cls): + return cls.toPath(cls.LOGGING_FILE_NAME) + + @classmethod + def getMaxLogSize(cls): + size = cls.LOGGING_MAX_FILE_SIZE + split = re.findall("\d+|\D+", size) + multipliers = {"KB": 1024, "MB": 1024 * 1024, "GB": 1024 * 1024 * 1024} + if len(split) == 2: + base = int(split[0]) + unit = split[1].strip().upper() + return base * multipliers.get(unit, 1024 * 1024) + # if size is not a correctly defined set it to 100MB + else: + return 100 * 1024 * 1024 + + @classmethod + def getBacklog(cls): + return cls.LOGGING_BACKLOG + + @classmethod + def getMongoConnection(cls): + mongoHost = cls.MONGO_HOST + mongoPort = cls.MONGO_PORT + mongoDB = cls.MONGO_DB + mongoUsername = urllib.parse.quote(cls.MONGO_USER) + mongoPassword = urllib.parse.quote(cls.MONGO_PASS) + if mongoUsername and mongoPassword: + mongoURI = "mongodb://{username}:{password}@{host}:{port}/{db}".format( + username=mongoUsername, + password=mongoPassword, + host=mongoHost, + port=mongoPort, + db=mongoDB, + ) + else: + mongoURI = "mongodb://{host}:{port}/{db}".format( + host=mongoHost, port=mongoPort, db=mongoDB + ) + connect = pymongo.MongoClient(mongoURI, connect=False) + return connect[mongoDB] + + @classmethod + def getRedisQConnection(cls): + redisHost = cls.REDIS_HOST + redisPort = cls.REDIS_PORT + redisDB = cls.REDIS_Q + redisPass = cls.REDIS_PASS + return redis.Redis( + host=redisHost, + port=redisPort, + db=redisDB, + password=redisPass, + charset="utf-8", + decode_responses=True, + ) diff --git a/CveXplore/update/DatabaseLayer.py b/CveXplore/update/DatabaseLayer.py new file mode 100644 index 00000000..8f9c1b60 --- /dev/null +++ b/CveXplore/update/DatabaseLayer.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Database layer translates database calls to functions +# +# Software is free software released under the "GNU Affero General Public License v3.0" +# +# Copyright (c) 2015-2018 Pieter-Jan Moreels - pieterjan.moreels@gmail.com + +# imports + +import pymongo + +from .Config import Configuration as conf + +# Variables +db = conf.getMongoConnection() +colCVE = db["cves"] +colCPE = db["cpe"] +colCWE = db["cwe"] +colCPEOTHER = db["cpeother"] +colINFO = db["info"] +colVIA4 = db["via4"] +colCAPEC = db["capec"] + +mongo_version = db.command("buildinfo")["versionArray"] +# to check if mongodb > 4.4 +# if it is, then use allow_disk_use for optimized queries +# to be removed in future with the conditional statements +# and use allow_disk_use by default + +# Functions +def sanitize(x): + if type(x) == pymongo.cursor.Cursor: + x = list(x) + if type(x) == list: + for y in x: + sanitize(y) + if x and "_id" in x: + x.pop("_id") + return x + + +# DB Functions +def ensureIndex(collection, field, **kwargs): + db[collection].create_index(field, **kwargs) + + +def drop(collection): + db[collection].drop() + + +def setColUpdate(collection, date): + colINFO.update({"db": collection}, {"$set": {"last-modified": date}}, upsert=True) + + +def setColInfo(collection, field, data): + colINFO.update({"db": collection}, {"$set": {field: data}}, upsert=True) + + +def updateCVE(cve): + if cve["cvss3"] is not None: + 
colCVE.update( + {"id": cve["id"]}, + { + "$set": { + "cvss3": cve["cvss3"], + "impact3": cve["impact3"], + "exploitability3": cve["exploitability3"], + "cvss3-vector": cve["cvss3-vector"], + "impactScore3": cve["impactScore3"], + "exploitabilityScore3": cve["exploitabilityScore3"], + "cvss": cve["cvss"], + "summary": cve["summary"], + "references": cve["references"], + "impact": cve["impact"], + "vulnerable_product": cve["vulnerable_product"], + "access": cve["access"], + "cwe": cve["cwe"], + "vulnerable_configuration": cve["vulnerable_configuration"], + "vulnerable_configuration_cpe_2_2": cve[ + "vulnerable_configuration_cpe_2_2" + ], + "last-modified": cve["Modified"], + } + }, + upsert=True, + ) + else: + colCVE.update( + {"id": cve["id"]}, + { + "$set": { + "cvss3": cve["cvss3"], + "cvss": cve["cvss"], + "summary": cve["summary"], + "references": cve["references"], + "impact": cve["impact"], + "vulnerable_product": cve["vulnerable_product"], + "access": cve["access"], + "cwe": cve["cwe"], + "vulnerable_configuration": cve["vulnerable_configuration"], + "vulnerable_configuration_cpe_2_2": cve[ + "vulnerable_configuration_cpe_2_2" + ], + "last-modified": cve["Modified"], + } + }, + upsert=True, + ) + + +def dropCollection(col): + return db[col].drop() + # jdt_NOTE: is exactly the same as drop(collection) + # jdt_NOTE: use only one of them + + +def getTableNames(): + # return db.collection_names() + # jdt_NOTE: collection_names() is depreated, list_collection_names() should be used instead + return db.list_collection_names() + + +def getCPEVersionInformation(query): + return sanitize(colCPE.find_one(query)) + + +def getCPEs(): + return sanitize(colCPE.find()) + + +def getInfo(collection): + return sanitize(colINFO.find_one({"db": collection})) diff --git a/CveXplore/update/DownloadHandler.py b/CveXplore/update/DownloadHandler.py new file mode 100644 index 00000000..bff94e0f --- /dev/null +++ b/CveXplore/update/DownloadHandler.py @@ -0,0 +1,427 @@ +import datetime +import gzip +import logging +import multiprocessing as mp +import os +import sys +import tempfile +import threading +import time +import zipfile +from abc import ABC, abstractmethod +from datetime import timedelta +from io import BytesIO +from itertools import islice +from shutil import copy + +import requests +from dateutil.parser import parse as parse_datetime +from pymongo.errors import BulkWriteError +from requests.adapters import HTTPAdapter +from tqdm import tqdm +from tqdm.contrib.concurrent import thread_map +from urllib3 import Retry + +from .DatabaseLayer import getInfo, setColUpdate +from .LogHandler import UpdateHandler +from .redis_q import RedisQueue, CveXploreQueue +from .Config import Configuration + +thread_local = threading.local() +logging.setLoggerClass(UpdateHandler) + +logging.getLogger("urllib3").setLevel(logging.WARNING) + + +class DownloadHandler(ABC): + """ + DownloadHandler is the base class for all downloads and subsequent processing of the downloaded content. + Each download script has a derived class which handles specifics for that type of content / download. 
+ """ + + def __init__(self, feed_type, prefix=None): + self._end = None + + self.feed_type = feed_type + + self.prefix = prefix + + self.queue = RedisQueue(name=self.feed_type) + + self.file_queue = RedisQueue(name=f"{self.feed_type}:files") + self.file_queue.clear() + + self.progress_bar = None + + self.last_modified = None + + self.do_process = True + + self.logger = logging.getLogger("DownloadHandler") + + self.config = Configuration() + + def __repr__(self): + """ return string representation of object """ + return "<< DownloadHandler:{} >>".format(self.feed_type) + + def get_session( + self, + retries=3, + backoff_factor=0.3, + status_forcelist=(429, 500, 502, 503, 504), + session=None, + ): + """ + Method for returning a session object per every requesting thread + """ + + proxies = {"http": self.config.getProxy(), "https": self.config.getProxy()} + + if not hasattr(thread_local, "session"): + session = session or requests.Session() + retry = Retry( + total=retries, + read=retries, + connect=retries, + backoff_factor=backoff_factor, + status_forcelist=status_forcelist, + ) + + session.proxies.update(proxies) + + adapter = HTTPAdapter(max_retries=retry) + session.mount("http://", adapter) + session.mount("https://", adapter) + + thread_local.session = session + + return thread_local.session + + def process_downloads(self, sites, collection): + """ + Method to download and process files + + :param sites: List of file to download and process + :type sites: list + :param collection: Mongodb Collection name + :type collection: str + :return: + :rtype: + """ + + worker_size = ( + int(os.getenv("WORKER_SIZE")) + if os.getenv("WORKER_SIZE") + else min(32, os.cpu_count() + 4) + ) + + start_time = time.time() + + thread_map(self.download_site, sites, desc="Downloading files") + + if self.do_process: + thread_map( + self.file_to_queue, + self.file_queue.get_full_list(), + desc="Processing downloaded files", + ) + + self._process_queue_to_db(worker_size, collection=collection) + + # checking if last-modified was in the response headers and not set to default + if "01-01-1970" != self.last_modified.strftime("%d-%m-%Y"): + setColUpdate(self.feed_type.lower(), self.last_modified) + + self.logger.info( + "Duration: {}".format(timedelta(seconds=time.time() - start_time)) + ) + + def chunk_list(self, lst, number): + """ + Yield successive n-sized chunks from lst. 
+ + :param lst: List to be chunked + :type lst: list + :param number: Chunk size + :type number: int + :return: Chunked list + :rtype: list + """ + for i in range(0, len(lst), number): + yield lst[i : i + number] + + def _handle_queue_progressbar(self, description): + """ + Method for handling the progressbar during queue processing + + :param description: Description for tqdm progressbar + :type description: str + """ + max_len = self.queue.qsize() + + pbar = tqdm(total=max_len, desc=description) + not_Done = True + q_len = max_len + dif_old = 0 + x = 0 + + while not_Done: + + current_q_len = self.queue.qsize() + + if x % 10 == 0: + # log stats the first cycle and every 10th cycle thereafter + self.logger.debug( + "Queue max_len: {}, current_q_len: {}, q_len: {}, dif_old: {}, cycle: {}".format( + max_len, current_q_len, q_len, dif_old, x + ) + ) + + if current_q_len != 0: + + if current_q_len != q_len: + q_len = current_q_len + dif = max_len - q_len + + pbar.update(int(dif - dif_old)) + + dif_old = dif + else: + pbar.update(int(max_len - dif_old)) + not_Done = False + + x += 1 + time.sleep(5) + + self.logger.debug( + "Queue max_len: {}, q_len: {}, dif_old: {}, cycles: {}".format( + max_len, q_len, dif_old, x + ) + ) + + pbar.close() + + def _process_queue_to_db(self, max_workers, collection): + """ + Method to write the queued database transactions into the database given a Queue reference and Collection name + + :param max_workers: Max amount of worker processes to use; defaults to min(32, os.cpu_count() + 4) + :type max_workers: int + :param collection: Mongodb Collection name + :type collection: str + """ + + pbar = mp.Process( + target=self._handle_queue_progressbar, + args=("Transferring queue to database",), + ) + + processes = [ + mp.Process(target=self._db_bulk_writer, args=(collection,)) + for _ in range(max_workers) + ] + for proc in processes: + proc.start() + # Put triggers in the Queue to tell the workers to exit their for-loop + self.queue.put(self._end) + + pbar.start() + + for proc in processes: + proc.join() + + pbar.join() + + def _db_bulk_writer(self, collection, threshold=1000): + """ + Method to act as worker for writing queued entries into the database + + :param collection: Mongodb Collection name + :type collection: str + :param threshold: Batch size threshold; defaults to 1000 + :type threshold: int + """ + database = self.config.getMongoConnection() + + for batch in iter(lambda: list(islice(self.queue, threshold)), []): + try: + database[collection].bulk_write(batch, ordered=False) + except BulkWriteError as err: + self.logger.debug("Error during bulk write: {}".format(err)) + pass + + def store_file(self, response_content, content_type, url): + """ + Method to store the download based on the headers content type + + :param response_content: Response content + :type response_content: bytes + :param content_type: Content type; e.g. 
'application/zip' + :type content_type: str + :param url: Download url + :type url: str + :return: A working directory and a filename + :rtype: str and str + """ + wd = tempfile.mkdtemp() + filename = None + + if ( + content_type == "application/zip" + or content_type == "application/x-zip" + or content_type == "application/x-zip-compressed" + or content_type == "application/zip-compressed" + ): + filename = os.path.join(wd, url.split("/")[-1][:-4]) + self.logger.debug("Saving file to: {}".format(filename)) + + with zipfile.ZipFile(BytesIO(response_content)) as zip_file: + zip_file.extractall(wd) + + elif ( + content_type == "application/x-gzip" + or content_type == "application/gzip" + or content_type == "application/x-gzip-compressed" + or content_type == "application/gzip-compressed" + ): + filename = os.path.join(wd, url.split("/")[-1][:-3]) + self.logger.debug("Saving file to: {}".format(filename)) + + buf = BytesIO(response_content) + with open(filename, "wb") as f: + f.write(gzip.GzipFile(fileobj=buf).read()) + + elif content_type == "application/json" or content_type == "application/xml": + filename = os.path.join(wd, url.split("/")[-1]) + self.logger.debug("Saving file to: {}".format(filename)) + + with open(filename, "wb") as output_file: + output_file.write(response_content) + + elif content_type == "application/local": + filename = os.path.join(wd, url.split("/")[-1]) + self.logger.debug("Saving file to: {}".format(filename)) + + copy(url[7:], filename) + + else: + self.logger.error( + "Unhandled Content-Type encountered: {} from url".format( + content_type, url + ) + ) + sys.exit(1) + + return wd, filename + + def download_site(self, url): + if url[:4] == "file": + self.logger.info("Scheduling local hosted file: {}".format(url)) + + # local file do not get last_modified header; so completely ignoring last_modified check and always asume + # local file == the last modified file and set to current time. + self.last_modified = datetime.datetime.now() + + self.logger.debug( + "Last {} modified value: {} for URL: {}".format( + self.feed_type, self.last_modified, url + ) + ) + + wd, filename = self.store_file( + response_content=b"local", content_type="application/local", url=url + ) + + if filename is not None: + self.file_queue.put((wd, filename)) + else: + self.logger.error( + "Unable to retrieve a filename; something went wrong when trying to save the file" + ) + sys.exit(1) + + else: + self.logger.debug("Downloading from url: {}".format(url)) + session = self.get_session() + try: + with session.get(url) as response: + try: + self.last_modified = parse_datetime( + response.headers["last-modified"], ignoretz=True + ) + except KeyError: + self.logger.error( + "Did not receive last-modified header in the response; setting to default " + "(01-01-1970) and force update! 
Headers received: {}".format( + response.headers + ) + ) + # setting to last_modified to default value + self.last_modified = parse_datetime("01-01-1970") + + self.logger.debug( + "Last {} modified value: {} for URL: {}".format( + self.feed_type, self.last_modified, url + ) + ) + + i = getInfo(self.feed_type.lower()) + + if i is not None: + if self.last_modified == i["last-modified"]: + self.logger.info( + "{}'s are not modified since the last update".format( + self.feed_type + ) + ) + self.file_queue.get_full_list() + self.do_process = False + if self.do_process: + content_type = response.headers["content-type"] + + self.logger.debug( + "URL: {} fetched Content-Type: {}".format(url, content_type) + ) + + wd, filename = self.store_file( + response_content=response.content, + content_type=content_type, + url=url, + ) + + if filename is not None: + self.file_queue.put((wd, filename)) + else: + self.logger.error( + "Unable to retrieve a filename; something went wrong when trying to save the file" + ) + sys.exit(1) + except Exception as err: + self.logger.info( + "Exception encountered during download from: {}. Please check the logs for more information!".format( + url + ) + ) + self.logger.error( + "Exception encountered during the download from: {}. Error encountered: {}".format( + url, err + ) + ) + self.do_process = False + + @abstractmethod + def process_item(self, **kwargs): + raise NotImplementedError + + @abstractmethod + def file_to_queue(self, *args): + raise NotImplementedError + + @abstractmethod + def update(self, **kwargs): + raise NotImplementedError + + @abstractmethod + def populate(self, **kwargs): + raise NotImplementedError diff --git a/CveXplore/update/IJSONHandler.py b/CveXplore/update/IJSONHandler.py new file mode 100644 index 00000000..71c799fd --- /dev/null +++ b/CveXplore/update/IJSONHandler.py @@ -0,0 +1,25 @@ +import logging + +import ijson + +from .LogHandler import UpdateHandler + +logging.setLoggerClass(UpdateHandler) + + +class IJSONHandler(object): + def __init__(self): + self.logger = logging.getLogger("IJSONHandler") + + def fetch(self, filename, prefix): + x = 0 + with open(filename, "rb") as input_file: + for item in ijson.items(input_file, prefix): + yield item + x += 1 + + self.logger.debug( + "Processed {} items from file: {}, using prefix: {}".format( + x, filename, prefix + ) + ) diff --git a/CveXplore/update/JSONFileHandler.py b/CveXplore/update/JSONFileHandler.py new file mode 100644 index 00000000..3039ba4f --- /dev/null +++ b/CveXplore/update/JSONFileHandler.py @@ -0,0 +1,56 @@ +import shutil +from abc import abstractmethod + +from .DownloadHandler import DownloadHandler +from .IJSONHandler import IJSONHandler + + +class JSONFileHandler(DownloadHandler): + def __init__(self, feed_type, prefix): + super().__init__(feed_type) + + self.is_update = True + + self.prefix = prefix + + self.ijson_handler = IJSONHandler() + + def __repr__(self): + """ return string representation of object """ + return "<< JSONFileHandler:{} >>".format(self.feed_type) + + def file_to_queue(self, file_tuple): + + working_dir, filename = file_tuple + + # adjust the interval counter for debug logging when updating + if self.is_update: + interval = 500 + else: + interval = 5000 + + x = 0 + self.logger.debug("Starting processing of file: {}".format(filename)) + for cpe in self.ijson_handler.fetch(filename=filename, prefix=self.prefix): + self.process_item(item=cpe) + x += 1 + if x % interval == 0: + self.logger.debug( + "Processed {} entries from file: {}".format(x, filename) + 
) + + try: + self.logger.debug("Removing working dir: {}".format(working_dir)) + shutil.rmtree(working_dir) + except Exception as err: + self.logger.error( + "Failed to remove working dir; error produced: {}".format(err) + ) + + @abstractmethod + def update(self, **kwargs): + raise NotImplementedError + + @abstractmethod + def populate(self, **kwargs): + raise NotImplementedError diff --git a/CveXplore/update/LogHandler.py b/CveXplore/update/LogHandler.py new file mode 100644 index 00000000..17119c21 --- /dev/null +++ b/CveXplore/update/LogHandler.py @@ -0,0 +1,177 @@ +""" +LogHandler.py +============= +""" +import logging +import os +import platform +from logging.config import dictConfig +from logging.handlers import RotatingFileHandler + +import colors + +from .Config import Configuration + + +class HostnameFilter(logging.Filter): + hostname = platform.node() + + def filter(self, record): + record.hostname = HostnameFilter.hostname + return True + + +class HelperLogger(logging.Logger): + """ + The HelperLogger is used by the application / gui as their logging class and *extends* the default python + logger.logging class. + + This will separate the logging from the application / gui from that of the daemons. + + """ + + runPath = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + + logPath = os.path.join(runPath, "log") + + if not os.path.exists(logPath): + os.makedirs(logPath) + + config = Configuration() + + logDict = { + "version": 1, + "formatters": { + "sysLogFormatter": { + "format": "%(asctime)s - %(name)-8s - %(levelname)-8s - %(message)s" + }, + "simpleFormatter": { + "format": "%(asctime)s - %(name)-8s - %(levelname)-8s - %(message)s" + }, + }, + "handlers": { + "consoleHandler": { + "class": "logging.StreamHandler", + "level": "INFO", + "stream": "ext://sys.stdout", + "formatter": "simpleFormatter", + } + }, + "root": {"level": "DEBUG", "handlers": ["consoleHandler"]}, + } + + dictConfig(logDict) + + level_map = { + "debug": "magenta", + "info": "white", + "warning": "yellow", + "error": "red", + "critical": "red", + } + + def __init__(self, name, level=logging.NOTSET): + + super().__init__(name, level) + + def debug(self, msg, *args, **kwargs): + """ + Log ‘msg % args’ with severity ‘DEBUG’ and color *MAGENTA. + + To pass exception information, use the keyword argument exc_info with a true value, e.g. + + logger.debug(“Houston, we have a %s”, “thorny problem”, exc_info=1) + + :param msg: Message to log + :type msg: str + """ + + msg = colors.color("{}".format(msg), fg=HelperLogger.level_map["debug"]) + + return super(HelperLogger, self).debug(msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + """ + Log ‘msg % args’ with severity ‘INFO’ and color *WHITE*. + + To pass exception information, use the keyword argument exc_info with a true value, e.g. + + logger.info(“Houston, we have a %s”, “interesting problem”, exc_info=1) + + :param msg: Message to log + :type msg: str + """ + + msg = colors.color("{}".format(msg), fg=HelperLogger.level_map["info"]) + + return super(HelperLogger, self).info(msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + """ + Log ‘msg % args’ with severity ‘WARNING’ and color *YELLOW*. + + To pass exception information, use the keyword argument exc_info with a true value, e.g. 
+ + logger.warning(“Houston, we have a %s”, “bit of a problem”, exc_info=1) + + :param msg: Message to log + :type msg: str + """ + + msg = colors.color("{}".format(msg), fg=HelperLogger.level_map["warning"]) + + return super(HelperLogger, self).warning(msg, *args, **kwargs) + + def error(self, msg, *args, **kwargs): + """ + Log ‘msg % args’ with severity ‘ERROR’ and color *RED*. + + Store logged message to the database for dashboard alerting. + + To pass exception information, use the keyword argument exc_info with a true value, e.g. + + logger.error(“Houston, we have a %s”, “major problem”, exc_info=1) + + :param msg: Message to log + :type msg: str + """ + + msg = colors.color("{}".format(msg), fg=HelperLogger.level_map["error"]) + + return super(HelperLogger, self).error(msg, *args, **kwargs) + + def critical(self, msg, *args, **kwargs): + """ + Log ‘msg % args’ with severity ‘CRITICAL’ and color *RED*. + + Store logged message to the database for dashboard alerting. + + To pass exception information, use the keyword argument exc_info with a true value, e.g. + + logger.critical(“Houston, we have a %s”, “hell of a problem”, exc_info=1) + + :param msg: Message to log + :type msg: str + """ + + msg = colors.color("{}".format(msg), fg=HelperLogger.level_map["critical"]) + + return super(HelperLogger, self).critical(msg, *args, **kwargs) + + +class UpdateHandler(HelperLogger): + def __init__(self, name, level=logging.NOTSET): + super().__init__(name, level) + + formatter = logging.Formatter( + "%(asctime)s - %(name)-8s - %(levelname)-8s - %(message)s" + ) + + crf = RotatingFileHandler( + filename=self.config.getUpdateLogFile(), + maxBytes=self.config.getMaxLogSize(), + backupCount=self.config.getBacklog(), + ) + crf.setLevel(logging.DEBUG) + crf.setFormatter(formatter) + self.addHandler(crf) diff --git a/CveXplore/update/Sources_process.py b/CveXplore/update/Sources_process.py new file mode 100644 index 00000000..68076418 --- /dev/null +++ b/CveXplore/update/Sources_process.py @@ -0,0 +1,1001 @@ +import datetime +import hashlib +import json +import logging +import shutil +from collections import namedtuple +from xml.sax import make_parser + +from dateutil.parser import parse as parse_datetime +from pymongo import TEXT, ASCENDING + +from .Config import Configuration +from .DatabaseLayer import ( + getTableNames, + dropCollection, + getCPEVersionInformation, + setColInfo, + getInfo, + ensureIndex, +) +from .JSONFileHandler import JSONFileHandler +from .Toolkit import generate_title +from .XMLFileHandler import XMLFileHandler +from .content_handlers import CapecHandler, CWEHandler +from .db_action import DatabaseAction + +# init parts of the file names to enable looped file download +file_prefix = "nvdcve-1.1-" +file_suffix = ".json.gz" +file_mod = "modified" +file_rec = "recent" + +date = datetime.datetime.now() +year = date.year + 1 + +# default config +defaultvalue = {"cwe": "Unknown"} + +cveStartYear = Configuration.getCVEStartYear() + + +class CPEDownloads(JSONFileHandler): + def __init__(self): + self.feed_type = "CPE" + self.prefix = "matches.item" + super().__init__(self.feed_type, self.prefix) + + self.feed_url = Configuration.getFeedURL(self.feed_type.lower()) + + self.logger = logging.getLogger("CPEDownloads") + + @staticmethod + def process_cpe_item(item=None): + if item is None: + return None + if "cpe23Uri" not in item: + return None + + cpe = { + "title": generate_title(item["cpe23Uri"]), + "cpe_2_2": item["cpe23Uri"], + "cpe_name": item["cpe_name"], + "vendor": 
item["cpe23Uri"].split(":")[3], + "product": item["cpe23Uri"].split(":")[4], + } + + version_info = "" + if "versionStartExcluding" in item: + cpe["versionStartExcluding"] = item["versionStartExcluding"] + version_info += cpe["versionStartExcluding"] + "_VSE" + if "versionStartIncluding" in item: + cpe["versionStartIncluding"] = item["versionStartIncluding"] + version_info += cpe["versionStartIncluding"] + "_VSI" + if "versionEndExcluding" in item: + cpe["versionEndExcluding"] = item["versionEndExcluding"] + version_info += cpe["versionEndExcluding"] + "_VEE" + if "versionEndIncluding" in item: + cpe["versionEndIncluding"] = item["versionEndIncluding"] + version_info += cpe["versionEndIncluding"] + "_VEI" + + sha1_hash = hashlib.sha1( + cpe["cpe_2_2"].encode("utf-8") + version_info.encode("utf-8") + ).hexdigest() + + cpe["id"] = sha1_hash + + return cpe + + def process_item(self, item): + cpe = self.process_cpe_item(item) + + if cpe is not None: + if self.is_update: + self.queue.put( + DatabaseAction( + action=DatabaseAction.actions.UpdateOne, + collection=self.feed_type.lower(), + doc=cpe, + ) + ) + else: + self.queue.put( + DatabaseAction( + action=DatabaseAction.actions.InsertOne, + collection=self.feed_type.lower(), + doc=cpe, + ) + ) + + def update(self, **kwargs): + self.logger.info("CPE database update started") + + # if collection is non-existent; assume it's not an update + if self.feed_type.lower() not in getTableNames(): + self.is_update = False + + self.process_downloads([self.feed_url], collection=self.feed_type.lower()) + + self.logger.info("Finished CPE database update") + + return self.last_modified + + def populate(self, **kwargs): + self.logger.info("CPE Database population started") + + self.queue.clear() + + dropCollection(self.feed_type.lower()) + + DatabaseIndexer().create_indexes(collection="cpe") + + self.is_update = False + + self.process_downloads([self.feed_url], collection=self.feed_type.lower()) + + self.logger.info("Finished CPE database population") + + return self.last_modified + + +class CVEDownloads(JSONFileHandler): + def __init__(self): + self.feed_type = "CVES" + self.prefix = "CVE_Items.item" + super().__init__(self.feed_type, self.prefix) + + self.feed_url = Configuration.getFeedURL("cve") + self.modfile = file_prefix + file_mod + file_suffix + self.recfile = file_prefix + file_rec + file_suffix + + self.logger = logging.getLogger("CVEDownloads") + + @staticmethod + def get_cve_year_range(): + """ + Method to fetch the years where we need cve's for + """ + for a_year in range(cveStartYear, year): + yield a_year + + @staticmethod + def get_cpe_info(cpeuri): + query = {} + version_info = "" + if "versionStartExcluding" in cpeuri: + query["versionStartExcluding"] = cpeuri["versionStartExcluding"] + version_info += query["versionStartExcluding"] + "_VSE" + if "versionStartIncluding" in cpeuri: + query["versionStartIncluding"] = cpeuri["versionStartIncluding"] + version_info += query["versionStartIncluding"] + "_VSI" + if "versionEndExcluding" in cpeuri: + query["versionEndExcluding"] = cpeuri["versionEndExcluding"] + version_info += query["versionEndExcluding"] + "_VEE" + if "versionEndIncluding" in cpeuri: + query["versionEndIncluding"] = cpeuri["versionEndIncluding"] + version_info += query["versionEndIncluding"] + "_VEI" + + return query, version_info + + @staticmethod + def add_if_missing(cve, key, value): + if value not in cve[key]: + cve[key].append(value) + return cve + + @staticmethod + def get_vendor_product(cpeUri): + vendor = 
cpeUri.split(":")[3] + product = cpeUri.split(":")[4] + return vendor, product + + @staticmethod + def stem(cpeUri): + cpeArr = cpeUri.split(":") + return ":".join(cpeArr[:5]) + + def process_cve_item(self, item=None): + if item is None: + return None + if "ASSIGNER" not in item["cve"]["CVE_data_meta"]: + item["cve"]["CVE_data_meta"]["ASSIGNER"] = None + + cve = { + "id": item["cve"]["CVE_data_meta"]["ID"], + "assigner": item["cve"]["CVE_data_meta"]["ASSIGNER"], + "Published": parse_datetime(item["publishedDate"], ignoretz=True), + "Modified": parse_datetime(item["lastModifiedDate"], ignoretz=True), + "last-modified": parse_datetime(item["lastModifiedDate"], ignoretz=True), + } + + for description in item["cve"]["description"]["description_data"]: + if description["lang"] == "en": + if "summary" in cve: + cve["summary"] += " {}".format(description["value"]) + else: + cve["summary"] = description["value"] + if "impact" in item: + cve["access"] = {} + cve["impact"] = {} + if "baseMetricV3" in item["impact"]: + cve["impact3"] = {} + cve["exploitability3"] = {} + cve["impact3"]["availability"] = item["impact"]["baseMetricV3"][ + "cvssV3" + ]["availabilityImpact"] + cve["impact3"]["confidentiality"] = item["impact"]["baseMetricV3"][ + "cvssV3" + ]["confidentialityImpact"] + cve["impact3"]["integrity"] = item["impact"]["baseMetricV3"]["cvssV3"][ + "integrityImpact" + ] + cve["exploitability3"]["attackvector"] = item["impact"]["baseMetricV3"][ + "cvssV3" + ]["attackVector"] + cve["exploitability3"]["attackcomplexity"] = item["impact"][ + "baseMetricV3" + ]["cvssV3"]["attackComplexity"] + cve["exploitability3"]["privilegesrequired"] = item["impact"][ + "baseMetricV3" + ]["cvssV3"]["privilegesRequired"] + cve["exploitability3"]["userinteraction"] = item["impact"][ + "baseMetricV3" + ]["cvssV3"]["userInteraction"] + cve["exploitability3"]["scope"] = item["impact"]["baseMetricV3"][ + "cvssV3" + ]["scope"] + cve["cvss3"] = float( + item["impact"]["baseMetricV3"]["cvssV3"]["baseScore"] + ) + cve["cvss3-vector"] = item["impact"]["baseMetricV3"]["cvssV3"][ + "vectorString" + ] + cve["impactScore3"] = float( + item["impact"]["baseMetricV3"]["impactScore"] + ) + cve["exploitabilityScore3"] = float( + item["impact"]["baseMetricV3"]["exploitabilityScore"] + ) + else: + cve["cvss3"] = None + if "baseMetricV2" in item["impact"]: + cve["access"]["authentication"] = item["impact"]["baseMetricV2"][ + "cvssV2" + ]["authentication"] + cve["access"]["complexity"] = item["impact"]["baseMetricV2"]["cvssV2"][ + "accessComplexity" + ] + cve["access"]["vector"] = item["impact"]["baseMetricV2"]["cvssV2"][ + "accessVector" + ] + cve["impact"]["availability"] = item["impact"]["baseMetricV2"][ + "cvssV2" + ]["availabilityImpact"] + cve["impact"]["confidentiality"] = item["impact"]["baseMetricV2"][ + "cvssV2" + ]["confidentialityImpact"] + cve["impact"]["integrity"] = item["impact"]["baseMetricV2"]["cvssV2"][ + "integrityImpact" + ] + cve["cvss"] = float( + item["impact"]["baseMetricV2"]["cvssV2"]["baseScore"] + ) + cve["exploitabilityScore"] = float( + item["impact"]["baseMetricV2"]["exploitabilityScore"] + ) + cve["impactScore"] = float( + item["impact"]["baseMetricV2"]["impactScore"] + ) + cve["cvss-time"] = parse_datetime( + item["lastModifiedDate"], ignoretz=True + ) # NVD JSON lacks the CVSS time which was present in the original XML format + cve["cvss-vector"] = item["impact"]["baseMetricV2"]["cvssV2"][ + "vectorString" + ] + else: + cve["cvss"] = None + if "references" in item["cve"]: + cve["references"] = [] + for 
ref in item["cve"]["references"]["reference_data"]: + cve["references"].append(ref["url"]) + if "configurations" in item: + cve["vulnerable_configuration"] = [] + cve["vulnerable_product"] = [] + cve["vendors"] = [] + cve["products"] = [] + cve["vulnerable_product_stems"] = [] + cve["vulnerable_configuration_stems"] = [] + for cpe in item["configurations"]["nodes"]: + if "cpe_match" in cpe: + for cpeuri in cpe["cpe_match"]: + if "cpe23Uri" not in cpeuri: + continue + if cpeuri["vulnerable"]: + query, version_info = self.get_cpe_info(cpeuri) + if query != {}: + query["id"] = hashlib.sha1( + cpeuri["cpe23Uri"].encode("utf-8") + + version_info.encode("utf-8") + ).hexdigest() + cpe_info = getCPEVersionInformation(query) + if cpe_info: + if cpe_info["cpe_name"]: + for vulnerable_version in cpe_info["cpe_name"]: + cve = self.add_if_missing( + cve, + "vulnerable_product", + vulnerable_version["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration", + vulnerable_version["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration_stems", + self.stem( + vulnerable_version["cpe23Uri"] + ), + ) + vendor, product = self.get_vendor_product( + vulnerable_version["cpe23Uri"] + ) + cve = self.add_if_missing( + cve, "vendors", vendor + ) + cve = self.add_if_missing( + cve, "products", product + ) + cve = self.add_if_missing( + cve, + "vulnerable_product_stems", + self.stem( + vulnerable_version["cpe23Uri"] + ), + ) + else: + cve = self.add_if_missing( + cve, + "vulnerable_product", + cpeuri["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration", + cpeuri["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + vendor, product = self.get_vendor_product( + cpeuri["cpe23Uri"] + ) + cve = self.add_if_missing( + cve, "vendors", vendor + ) + cve = self.add_if_missing( + cve, "products", product + ) + cve = self.add_if_missing( + cve, + "vulnerable_product_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + else: + # If the cpe_match did not have any of the version start/end modifiers, + # add the CPE string as it is. 
+ cve = self.add_if_missing( + cve, "vulnerable_product", cpeuri["cpe23Uri"] + ) + cve = self.add_if_missing( + cve, "vulnerable_configuration", cpeuri["cpe23Uri"] + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + vendor, product = self.get_vendor_product( + cpeuri["cpe23Uri"] + ) + cve = self.add_if_missing(cve, "vendors", vendor) + cve = self.add_if_missing(cve, "products", product) + cve = self.add_if_missing( + cve, + "vulnerable_product_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + else: + cve = self.add_if_missing( + cve, "vulnerable_configuration", cpeuri["cpe23Uri"] + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + if "children" in cpe: + for child in cpe["children"]: + if "cpe_match" in child: + for cpeuri in child["cpe_match"]: + if "cpe23Uri" not in cpeuri: + continue + if cpeuri["vulnerable"]: + query, version_info = self.get_cpe_info(cpeuri) + if query != {}: + query["id"] = hashlib.sha1( + cpeuri["cpe23Uri"].encode("utf-8") + + version_info.encode("utf-8") + ).hexdigest() + cpe_info = getCPEVersionInformation(query) + if cpe_info: + if cpe_info["cpe_name"]: + for vulnerable_version in cpe_info[ + "cpe_name" + ]: + cve = self.add_if_missing( + cve, + "vulnerable_product", + vulnerable_version["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration", + vulnerable_version["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration_stems", + self.stem( + vulnerable_version[ + "cpe23Uri" + ] + ), + ) + ( + vendor, + product, + ) = self.get_vendor_product( + vulnerable_version["cpe23Uri"] + ) + cve = self.add_if_missing( + cve, "vendors", vendor + ) + cve = self.add_if_missing( + cve, "products", product + ) + cve = self.add_if_missing( + cve, + "vulnerable_product_stems", + self.stem( + vulnerable_version[ + "cpe23Uri" + ] + ), + ) + else: + cve = self.add_if_missing( + cve, + "vulnerable_product", + cpeuri["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration", + cpeuri["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + ( + vendor, + product, + ) = self.get_vendor_product( + cpeuri["cpe23Uri"] + ) + cve = self.add_if_missing( + cve, "vendors", vendor + ) + cve = self.add_if_missing( + cve, "products", product + ) + cve = self.add_if_missing( + cve, + "vulnerable_product_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + else: + # If the cpe_match did not have any of the version start/end modifiers, + # add the CPE string as it is. 
+ if "cpe23Uri" not in cpeuri: + continue + cve = self.add_if_missing( + cve, + "vulnerable_product", + cpeuri["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration", + cpeuri["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + vendor, product = self.get_vendor_product( + cpeuri["cpe23Uri"] + ) + cve = self.add_if_missing( + cve, "vendors", vendor + ) + cve = self.add_if_missing( + cve, "products", product + ) + cve = self.add_if_missing( + cve, + "vulnerable_product_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + else: + if "cpe23Uri" not in cpeuri: + continue + cve = self.add_if_missing( + cve, + "vulnerable_configuration", + cpeuri["cpe23Uri"], + ) + cve = self.add_if_missing( + cve, + "vulnerable_configuration_stems", + self.stem(cpeuri["cpe23Uri"]), + ) + if "problemtype" in item["cve"]: + for problem in item["cve"]["problemtype"]["problemtype_data"]: + for cwe in problem[ + "description" + ]: # NVD JSON not clear if we can get more than one CWE per CVE (until we take the last one) - + # NVD-CWE-Other??? list? + if cwe["lang"] == "en": + cve["cwe"] = cwe["value"] + if not ("cwe" in cve): + cve["cwe"] = defaultvalue["cwe"] + else: + cve["cwe"] = defaultvalue["cwe"] + cve["vulnerable_configuration_cpe_2_2"] = [] + return cve + + def process_item(self, item): + cve = self.process_cve_item(item) + + if cve is not None: + if self.is_update: + self.queue.put( + DatabaseAction( + action=DatabaseAction.actions.UpdateOne, + collection=self.feed_type.lower(), + doc=cve, + ) + ) + else: + self.queue.put( + DatabaseAction( + action=DatabaseAction.actions.InsertOne, + collection=self.feed_type.lower(), + doc=cve, + ) + ) + + def update(self): + self.logger.info("CVE database update started") + + # if collection is non-existent; assume it's not an update + if "cves" not in getTableNames(): + self.is_update = False + + self.process_downloads( + [self.feed_url + self.modfile, self.feed_url + self.recfile], + collection=self.feed_type.lower(), + ) + + self.logger.info("Finished CVE database update") + + return self.last_modified + + def populate(self): + urls = [] + + self.logger.info("CVE database population started") + + self.logger.info( + "Starting CVE database population starting from year: {}".format( + cveStartYear + ) + ) + + self.is_update = False + + self.queue.clear() + + dropCollection("cves") + + DatabaseIndexer().create_indexes(collection="cves") + + for x in self.get_cve_year_range(): + getfile = file_prefix + str(x) + file_suffix + + urls.append(self.feed_url + getfile) + + self.process_downloads(urls, collection=self.feed_type.lower()) + + self.logger.info("Finished CVE database population") + + return self.last_modified + + +class VIADownloads(JSONFileHandler): + def __init__(self): + self.feed_type = "VIA4" + self.prefix = "cves" + super().__init__(self.feed_type, self.prefix) + + self.feed_url = Configuration.getFeedURL(self.feed_type.lower()) + + self.logger = logging.getLogger("VIADownloads") + + def file_to_queue(self, file_tuple): + + working_dir, filename = file_tuple + + for cve in self.ijson_handler.fetch(filename=filename, prefix=self.prefix): + x = 0 + for key, val in cve.items(): + entry_dict = {"id": key} + entry_dict.update(val) + self.process_item(item=entry_dict) + x += 1 + + self.logger.debug("Processed {} items from file: {}".format(x, filename)) + + with open(filename, "rb") as input_file: + data = json.loads(input_file.read().decode("utf-8")) + + 
setColInfo("via4", "sources", data["metadata"]["sources"]) + setColInfo("via4", "searchables", data["metadata"]["searchables"]) + + self.logger.debug("Processed metadata from file: {}".format(filename)) + + try: + self.logger.debug("Removing working dir: {}".format(working_dir)) + shutil.rmtree(working_dir) + except Exception as err: + self.logger.error( + "Failed to remove working dir; error produced: {}".format(err) + ) + + def process_item(self, item): + + if self.is_update: + self.queue.put( + DatabaseAction( + action=DatabaseAction.actions.UpdateOne, + collection=self.feed_type.lower(), + doc=item, + ) + ) + else: + self.queue.put( + DatabaseAction( + action=DatabaseAction.actions.InsertOne, + collection=self.feed_type.lower(), + doc=item, + ) + ) + + def update(self, **kwargs): + self.logger.info("VIA4 database update started") + + # if collection is non-existent; assume it's not an update + if self.feed_type.lower() not in getTableNames(): + self.is_update = False + + self.process_downloads([self.feed_url], collection=self.feed_type.lower()) + + self.logger.info("Finished VIA4 database update") + + return self.last_modified + + def populate(self, **kwargs): + self.is_update = False + self.queue.clear() + return self.update() + + +class CAPECDownloads(XMLFileHandler): + def __init__(self): + self.feed_type = "CAPEC" + super().__init__(self.feed_type) + + self.feed_url = Configuration.getFeedURL(self.feed_type.lower()) + + self.logger = logging.getLogger("CAPECDownloads") + + # make parser + self.parser = make_parser() + self.ch = CapecHandler() + self.parser.setContentHandler(self.ch) + + def file_to_queue(self, file_tuple): + + working_dir, filename = file_tuple + + self.parser.parse(filename) + x = 0 + for attack in self.ch.capec: + self.process_item(attack) + x += 1 + + self.logger.debug("Processed {} entries from file: {}".format(x, filename)) + + try: + self.logger.debug("Removing working dir: {}".format(working_dir)) + shutil.rmtree(working_dir) + except Exception as err: + self.logger.error( + "Failed to remove working dir; error produced: {}".format(err) + ) + + def update(self, **kwargs): + self.logger.info("CAPEC database update started") + + # if collection is non-existent; assume it's not an update + if self.feed_type.lower() not in getTableNames(): + self.is_update = False + + self.process_downloads([self.feed_url], collection=self.feed_type.lower()) + + self.logger.info("Finished CAPEC database update") + + return self.last_modified + + def populate(self, **kwargs): + self.is_update = False + self.queue.clear() + return self.update() + + +class CWEDownloads(XMLFileHandler): + def __init__(self): + self.feed_type = "CWE" + super().__init__(self.feed_type) + + self.feed_url = Configuration.getFeedURL(self.feed_type.lower()) + + self.logger = logging.getLogger("CWEDownloads") + + # make parser + self.parser = make_parser() + self.ch = CWEHandler() + self.parser.setContentHandler(self.ch) + + def file_to_queue(self, file_tuple): + + working_dir, filename = file_tuple + + self.parser.parse(filename) + x = 0 + for cwe in self.ch.cwe: + try: + cwe["related_weaknesses"] = list(set(cwe["related_weaknesses"])) + except KeyError: + pass + self.process_item(cwe) + x += 1 + + self.logger.debug("Processed {} entries from file: {}".format(x, filename)) + + try: + self.logger.debug("Removing working dir: {}".format(working_dir)) + shutil.rmtree(working_dir) + except Exception as err: + self.logger.error( + "Failed to remove working dir; error produced: {}".format(err) + ) + + def 
update(self, **kwargs): + self.logger.info("CWE database update started") + + # if collection is non-existent; assume it's not an update + if self.feed_type.lower() not in getTableNames(): + self.is_update = False + + self.process_downloads([self.feed_url], collection=self.feed_type.lower()) + + self.logger.info("Finished CWE database update") + + return self.last_modified + + def populate(self, **kwargs): + self.is_update = False + self.queue.clear() + return self.update() + + +MongoUniqueIndex = namedtuple("MongoUniqueIndex", "index name unique weights") +MongoAddIndex = namedtuple("MongoAddIndex", "index name weights") + + +class DatabaseIndexer(object): + def __init__(self): + + self.indexes = { + "cpe": [ + MongoUniqueIndex( + index=[("id", ASCENDING)], + name="id", + unique=True, + weights={"id": 10}, + ), + MongoAddIndex( + index=[("vendor", ASCENDING)], name="vendor", weights={"vendor": 5} + ), + MongoAddIndex( + index=[("product", ASCENDING)], + name="product", + weights={"product": 3}, + ), + ], + "cpeother": [ + MongoUniqueIndex( + index=[("id", ASCENDING)], + name="id", + unique=True, + weights={"id": 10}, + ) + ], + "cves": [ + MongoAddIndex(index=[("id", ASCENDING)], name="id", weights={"id": 10}), + MongoAddIndex( + index=[("vulnerable_configuration", ASCENDING)], + name="vulnerable_configuration", + weights={"vulnerable_configuration": 3}, + ), + MongoAddIndex( + index=[("vulnerable_product", ASCENDING)], + name="vulnerable_product", + weights={"vulnerable_product": 3}, + ), + MongoAddIndex( + index=[("Modified", ASCENDING)], + name="Modified", + weights={"Modified": 3}, + ), + MongoAddIndex( + index=[("Published", ASCENDING)], + name="Published", + weights={"Published": 3}, + ), + MongoAddIndex( + index=[("last-modified", ASCENDING)], + name="last-modified", + weights={"last-modified": 3}, + ), + MongoAddIndex( + index=[("cvss", ASCENDING)], name="cvss", weights={"cvss": 5} + ), + MongoAddIndex( + index=[("cvss3", ASCENDING)], name="cvss3", weights={"cvss3": 5} + ), + MongoAddIndex( + index=[("summary", TEXT)], name="summary", weights={"summary": 5} + ), + MongoAddIndex( + index=[("vendors", ASCENDING)], + name="vendors", + weights={"vendors": 5}, + ), + MongoAddIndex( + index=[("products", ASCENDING)], + name="products", + weights={"products": 5}, + ), + MongoAddIndex( + index=[("vulnerable_product_stems", ASCENDING)], + name="vulnerable_product_stems", + weights={"vulnerable_product_stems": 5}, + ), + MongoAddIndex( + index=[("vulnerable_configuration_stems", ASCENDING)], + name="vulnerable_configuration_stems", + weights={"vulnerable_configuration_stems": 5}, + ), + ], + "via4": [ + MongoAddIndex(index=[("id", ASCENDING)], name="id", weights={"id": 10}) + ], + "mgmt_whitelist": [ + MongoAddIndex(index=[("id", ASCENDING)], name="id", weights={"id": 10}) + ], + "mgmt_blacklist": [ + MongoAddIndex(index=[("id", ASCENDING)], name="id", weights={"id": 10}) + ], + "capec": [ + MongoAddIndex( + index=[("related_weakness", ASCENDING)], + name="related_weakness", + weights={"related_weakness": 10}, + ) + ], + } + + self.logger = logging.getLogger("DatabaseIndexer") + + def create_indexes(self, collection=None): + + if collection is not None: + try: + for each in self.indexes[collection]: + if isinstance(each, MongoUniqueIndex): + self.setIndex( + collection, + each.index, + name=each.name, + unique=each.unique, + weights=each.weights, + ) + elif isinstance(each, MongoAddIndex): + self.setIndex( + collection, each.index, name=each.name, weights=each.weights + ) + except KeyError: 
+ # no specific index given, continue + self.logger.warning( + "Could not find the requested collection: {}, skipping...".format( + collection + ) + ) + pass + + else: + for index in self.iter_indexes(): + self.setIndex(index[0], index[1]) + + for collection in self.indexes.keys(): + for each in self.indexes[collection]: + if isinstance(each, MongoUniqueIndex): + self.setIndex( + collection, + each.index, + name=each.name, + unique=each.unique, + weights=each.weights, + ) + elif isinstance(each, MongoAddIndex): + self.setIndex( + collection, each.index, name=each.name, weights=each.weights + ) + + def iter_indexes(self): + for each in self.get_via4_indexes(): + yield each + + def get_via4_indexes(self): + via4 = getInfo("via4") + result = [] + if via4: + for index in via4.get("searchables", []): + result.append(("via4", index)) + return result + + def setIndex(self, col, field, **kwargs): + try: + ensureIndex(col, field, **kwargs) + self.logger.info("Success to create index %s on %s" % (field, col)) + except Exception as e: + self.logger.error("Failed to create index %s on %s: %s" % (col, field, e)) diff --git a/CveXplore/update/Toolkit.py b/CveXplore/update/Toolkit.py new file mode 100644 index 00000000..1813ac87 --- /dev/null +++ b/CveXplore/update/Toolkit.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Toolkit for functions between scripts +# +# Software is free software released under the "GNU Affero General Public License v3.0" +# +# Copyright (c) 2014-2018 Pieter-Jan Moreels - pieterjan.moreels@gmail.com + +import re + +import dateutil.parser +from dateutil import tz + + +# Note of warning: CPEs like cpe:/o:microsoft:windows_8:-:-:x64 are given to us by Mitre +# x64 will be parsed as Edition in this case, not Architecture +def toStringFormattedCPE(cpe, autofill=False): + cpe = cpe.strip() + if not cpe.startswith("cpe:2.3:"): + if not cpe.startswith("cpe:/"): + return False + cpe = cpe.replace("cpe:/", "cpe:2.3:") + cpe = cpe.replace("::", ":-:") + cpe = cpe.replace("~-", "~") + cpe = cpe.replace("~", ":-:") + cpe = cpe.replace("::", ":") + cpe = cpe.strip(":-") + cpe = unquote(cpe) + if autofill: + e = cpe.split(":") + for x in range(0, 13 - len(e)): + cpe += ":-" + return cpe + + +# Note of warning: Old CPE's can come in different formats, and are not uniform. 
Possibilities are: +# cpe:/a:7-zip:7-zip:4.65::~~~~x64~ +# cpe:/a:7-zip:7-zip:4.65:-:~~~~x64~ +# cpe:/a:7-zip:7-zip:4.65:-:~-~-~-~x64~ +def toOldCPE(cpe): + cpe = cpe.strip() + if not cpe.startswith("cpe:/"): + if not cpe.startswith("cpe:2.3:"): + return False + cpe = cpe.replace("cpe:2.3:", "") + parts = cpe.split(":") + next = [] + first = "cpe:/" + ":".join(parts[:5]) + last = parts[5:] + if last: + for x in last: + next.append("~") if x == "-" else next.append(x) + if "~" in next: + pad(next, 6, "~") + cpe = "%s:%s" % (first, "".join(next)) + cpe = cpe.replace(":-:", "::") + cpe = cpe.strip(":") + return cpe + + +def pad(seq, target_length, padding=None): + length = len(seq) + if length > target_length: + return seq + seq.extend([padding] * (target_length - length)) + return seq + + +def currentTime(utc): + timezone = tz.tzlocal() + utc = dateutil.parser.parse(utc) + output = utc.astimezone(timezone) + output = output.strftime("%d-%m-%Y - %H:%M") + return output + + +def isURL(string): + urlTypes = [re.escape(x) for x in ["http://", "https://", "www."]] + return re.match("^(" + "|".join(urlTypes) + ")", string) + + +def vFeedName(string): + string = string.replace("map_", "") + string = string.replace("cve_", "") + return string.title() + + +def mergeSearchResults(database, plugins): + if "errors" in database: + results = {"data": [], "errors": database["errors"]} + else: + results = {"data": []} + + data = [] + data.extend(database["data"]) + data.extend(plugins["data"]) + for cve in data: + if not any(cve["id"] == entry["id"] for entry in results["data"]): + results["data"].append(cve) + return results + + +def tk_compile(regexes): + if type(regexes) not in [list, tuple]: + regexes = [regexes] + r = [] + for rule in regexes: + r.append(re.compile(rule)) + return r + + +# Convert cpe2.2 url encoded to cpe2.3 char escaped +# cpe:2.3:o:cisco:ios:12.2%281%29 to cpe:2.3:o:cisco:ios:12.2\(1\) +def unquote(cpe): + return re.compile("%([0-9a-fA-F]{2})", re.M).sub( + lambda m: "\\" + chr(int(m.group(1), 16)), cpe + ) + + +# Generates a human readable title from a CPE 2.3 string +def generate_title(cpe): + title = "" + + cpe_split = cpe.split(":") + # Do a very basic test to see if the CPE is valid + if len(cpe_split) == 13: + + # Combine vendor, product and version + title = " ".join(cpe_split[3:6]) + + # If "other" is specified, add it to the title + if cpe_split[12] != "*": + title += cpe_split[12] + + # Capitilize each word + title = title.title() + + # If the target_sw is defined, add "for " to title + if cpe_split[10] != "*": + title += " for " + cpe_split[10] + + # In CPE 2.3 spaces are replaced with underscores. Undo it + title = title.replace("_", " ") + + # Special characters are escaped with \. 
Undo it + title = title.replace("\\", "") + + return title diff --git a/CveXplore/update/XMLFileHandler.py b/CveXplore/update/XMLFileHandler.py new file mode 100644 index 00000000..9b4ba3c5 --- /dev/null +++ b/CveXplore/update/XMLFileHandler.py @@ -0,0 +1,45 @@ +from abc import abstractmethod + +from .DownloadHandler import DownloadHandler +from .db_action import DatabaseAction + + +class XMLFileHandler(DownloadHandler): + def __init__(self, feed_type): + super().__init__(feed_type) + self.is_update = True + + def __repr__(self): + """ return string representation of object """ + return "<< XMLFileHandler:{} >>".format(self.feed_type) + + def process_item(self, item): + + if self.is_update: + self.queue.put( + DatabaseAction( + action=DatabaseAction.actions.UpdateOne, + collection=self.feed_type.lower(), + doc=item, + ) + ) + else: + self.queue.put( + DatabaseAction( + action=DatabaseAction.actions.InsertOne, + collection=self.feed_type.lower(), + doc=item, + ) + ) + + @abstractmethod + def file_to_queue(self, *args): + raise NotImplementedError + + @abstractmethod + def update(self, **kwargs): + raise NotImplementedError + + @abstractmethod + def populate(self, **kwargs): + raise NotImplementedError diff --git a/CveXplore/update/__init__.py b/CveXplore/update/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/CveXplore/update/__init__.py @@ -0,0 +1 @@ + diff --git a/CveXplore/update/content_handlers.py b/CveXplore/update/content_handlers.py new file mode 100644 index 00000000..54797ae0 --- /dev/null +++ b/CveXplore/update/content_handlers.py @@ -0,0 +1,455 @@ +from collections import defaultdict +from xml.sax.handler import ContentHandler + + +class CapecHandler(ContentHandler): + def __init__(self): + self.capec = [] + self.Attack_Pattern_Catalog_tag = False + self.Attack_Patterns_tag = False + self.Attack_Pattern_tag = False + self.Attack_step_tag = False + self.Description_tag = False + self.Text_tag = False + self.Prerequisites_tag = False + self.Prerequisite_tag = False + self.Mitigations_tag = False + self.Mitigation_tag = False + self.Related_Weaknesses_tag = False + self.Related_Weakness_tag = False + self.CWE_ID_tag = False + self.Related_Attack_Patterns = False + self.Taxonomy_Mappings = False + self.Taxonomy_Mapping = False + self.Likelihood_Of_Attack = False + self.Typical_Severity = False + + self.Execution_Flow = False + self.Step = True + self.Phase = True + self.Attack_Description = True + self.Technique = True + + self.entry_id = False + self.entry_name = False + + self.tag = False + + self.id = "" + self.name = "" + + self.Summary_ch = "" + self.Prerequisite_ch = "" + self.Mitigation_ch = "" + self.CWE_ID_ch = "" + self.entry_id_ch = "" + self.entry_name_ch = "" + + self.taxonomy_name = "" + self.step_name = "" + + self.Step_ch = "" + self.Phase_ch = "" + self.Attack_Description_ch = "" + self.Technique_ch = "" + self.Likelihood_Of_Attack_ch = "" + self.Typical_Severity_ch = "" + self.loa = "" + self.ts = "" + + self.Summary = [] + self.Prerequisite = [] + self.Solution_or_Mitigation = [] + self.Related_Weakness = [] + self.Related_AttackPatterns = [] + self.techniques = [] + + self.taxonomy_mapping = defaultdict(dict) + + self.execution_flow = defaultdict(dict) + + def startElement(self, name, attrs): + + if name == "Attack_Pattern_Catalog": + self.Attack_Pattern_Catalog_tag = True + if name == "Attack_Patterns" and self.Attack_Pattern_Catalog_tag: + self.Attack_Patterns_tag = True + if name == "Attack_Pattern" and self.Attack_Patterns_tag: + 
self.Attack_Pattern_tag = True + + if self.Attack_Pattern_tag: + self.tag = name + if self.tag == "Attack_Pattern": + self.id = attrs.getValue("ID") + self.name = attrs.getValue("Name") + + if self.tag == "Related_Attack_Patterns": + self.Related_Attack_Patterns = True + + if self.tag == "Related_Attack_Pattern" and self.Related_Attack_Patterns: + self.Related_AttackPatterns.append(attrs.get("CAPEC_ID")) + + if self.tag == "Taxonomy_Mappings": + self.Taxonomy_Mappings = True + + if self.tag == "Taxonomy_Mapping" and self.Taxonomy_Mappings: + self.Taxonomy_Mapping = True + self.taxonomy_name = attrs.get("Taxonomy_Name") + + if self.tag == "Entry_ID" and self.Taxonomy_Mappings: + self.entry_id = True + self.entry_id_ch = "" + + if self.tag == "Entry_Name" and self.Taxonomy_Mappings: + self.entry_name = True + self.entry_name_ch = "" + + if self.tag == "Execution_Flow": + self.Execution_Flow = True + + if self.tag == "Attack_Step" and self.Execution_Flow: + self.Attack_step_tag = True + + if self.tag == "Step" and self.Attack_step_tag: + self.Step = True + self.Step_ch = "" + + if self.tag == "Phase" and self.Attack_step_tag: + self.Phase = True + self.Phase_ch = "" + + if self.tag == "Description" and self.Attack_step_tag: + self.Attack_Description = True + self.Attack_Description_ch = "" + + if self.tag == "Technique" and self.Attack_step_tag: + self.Technique = True + self.Technique_ch = "" + + if self.tag == "Description" and not self.Attack_step_tag: + self.Description_tag = True + self.Summary_ch = "" + + if self.tag == "Prerequisites": + self.Prerequisites_tag = True + if name == "Prerequisite" and self.Prerequisites_tag: + self.Prerequisite_tag = True + self.Prerequisite_ch = "" + + if self.tag == "Mitigations": + self.Mitigations_tag = True + if name == "Mitigation" and self.Mitigations_tag: + self.Mitigation_tag = True + if name == "xhtml:p" and self.Mitigation_tag: + self.Text_tag = True + self.Mitigation_ch = "" + + if self.tag == "Related_Weaknesses": + self.Related_Weaknesses_tag = True + if name == "Related_Weakness" and self.Related_Weaknesses_tag: + self.Related_Weakness.append(attrs.getValue("CWE_ID")) + + if self.tag == "Likelihood_Of_Attack": + self.Likelihood_Of_Attack = True + self.Likelihood_Of_Attack_ch = "" + + if self.tag == "Typical_Severity": + self.Typical_Severity = True + self.Typical_Severity_ch = "" + + def characters(self, ch): + if self.Description_tag: + self.Summary_ch += ch + + if self.Prerequisite_tag: + self.Prerequisite_ch += ch + + if self.Text_tag: + if self.Mitigation_tag: + self.Mitigation_ch += ch + + if self.entry_id: + self.entry_id_ch += ch + + if self.entry_name: + self.entry_name_ch += ch + + if self.Step: + self.Step_ch += ch + + if self.Phase: + self.Phase_ch += ch + + if self.Attack_Description: + self.Attack_Description_ch += ch + + if self.Technique: + self.Technique_ch += ch + + if self.Likelihood_Of_Attack: + self.Likelihood_Of_Attack_ch += ch + + if self.Typical_Severity: + self.Typical_Severity_ch += ch + + def endElement(self, name): + if name == "Description" and not self.Attack_step_tag: + self.Summary.append(self.Summary_ch.rstrip()) + if self.Summary_ch != "": + self.Summary_ch = "" + self.Description_tag = False + + if name == "Entry_ID": + self.entry_id = False + + if name == "Entry_Name": + self.entry_name = False + + entry_id = self.entry_id_ch.rstrip() + + cut_entry = entry_id.split(".") + + url = "" + + if self.taxonomy_name == "ATTACK": + if len(cut_entry) == 1: + # no subtechnique use plain entry_id + url = 
"https://attack.mitre.org/techniques/T{}".format(entry_id) + else: + # attack with subtechniques use cut_entry + url = "https://attack.mitre.org/techniques/T{}/{}".format( + cut_entry[0], cut_entry[1] + ) + + elif self.taxonomy_name == "WASC": + + if "/" in self.entry_name_ch: + url = "http://projects.webappsec.org/{}".format( + self.entry_name_ch.replace("/", " and ").replace(" ", "-") + ) + else: + url = "http://projects.webappsec.org/{}".format( + self.entry_name_ch.replace(" ", "-") + ) + + elif self.taxonomy_name == "OWASP Attacks": + entry_id = "Link" + + url = "https://owasp.org/www-community/attacks/{}".format( + self.entry_name_ch.replace(" ", "_") + ) + + self.taxonomy_mapping[self.taxonomy_name][ + self.entry_id_ch.rstrip().replace(".", "_") + ] = { + "Entry_ID": entry_id, + "Entry_Name": self.entry_name_ch.rstrip(), + "URL": url, + } + + if self.entry_id_ch != "": + self.entry_id_ch = "" + + if self.entry_name_ch != "": + self.entry_name_ch = "" + + if name == "Taxonomy_Mappings": + self.Taxonomy_Mappings = False + + if name == "Taxonomy_Mapping": + self.Taxonomy_Mapping = False + + if name == "Step": + self.step_name = self.Step_ch.rstrip() + self.Step = False + + if name == "Phase": + self.Phase = False + + if name == "Description" and self.Attack_step_tag: + self.Attack_Description = False + + self.execution_flow[self.step_name] = { + "Phase": self.Phase_ch.rstrip(), + "Description": self.Attack_Description_ch.rstrip(), + "Techniques": [], + } + + if self.Step_ch != "": + self.Step_ch = "" + + if self.Phase_ch != "": + self.Phase_ch = "" + + if self.Attack_Description_ch != "": + self.Attack_Description_ch = "" + + if name == "Technique" and self.Attack_step_tag: + if self.Technique_ch != "": + self.execution_flow[self.step_name]["Techniques"].append( + self.Technique_ch.rstrip() + ) + self.Technique_ch = "" + self.Technique = False + + if name == "Attack_Step": + self.Attack_step_tag = False + + if name == "Execution_Flow": + self.Execution_Flow = False + + if name == "Prerequisite": + if self.Prerequisite_ch != "": + self.Prerequisite.append(self.Prerequisite_ch.rstrip()) + self.Prerequisite_tag = False + if name == "Mitigation": + if self.Mitigation_ch != "": + self.Solution_or_Mitigation.append(self.Mitigation_ch.rstrip()) + self.Mitigation_ch = "" + self.Mitigation_tag = False + + if name == "Prerequisites": + self.Prerequisites_tag = False + if name == "Mitigations": + self.Mitigations_tag = False + if name == "Related_Weaknesses": + self.Related_Weaknesses_tag = False + + if name == "Related_Attack_Patterns": + self.Related_Attack_Patterns = False + + if name == "Likelihood_Of_Attack": + self.Likelihood_Of_Attack = False + self.loa = self.Likelihood_Of_Attack_ch.rstrip() + self.Likelihood_Of_Attack_ch = "" + + if name == "Typical_Severity": + self.Typical_Severity = False + self.ts = self.Typical_Severity_ch.rstrip() + self.Typical_Severity_ch = "" + + if name == "Attack_Pattern": + if not self.name.startswith("DEPRECATED"): + self.capec.append( + { + "name": self.name, + "id": self.id, + "summary": "\n".join(self.Summary), + "prerequisites": " ".join(self.Prerequisite), + "solutions": " ".join(self.Solution_or_Mitigation), + "related_capecs": sorted(self.Related_AttackPatterns), + "related_weakness": sorted(self.Related_Weakness), + "taxonomy": dict(self.taxonomy_mapping), + "execution_flow": dict(self.execution_flow), + "loa": self.loa, + "typical_severity": self.ts, + } + ) + self.Summary = [] + self.Prerequisite = [] + self.Solution_or_Mitigation = [] + 
self.Related_Weakness = [] + self.Related_AttackPatterns = [] + self.techniques = [] + + self.taxonomy_mapping = defaultdict(dict) + + self.execution_flow = defaultdict(dict) + + self.Attack_Pattern_tag = False + if name == "Attack_Patterns": + self.Attack_Patterns_tag = False + if name == "Attack_Pattern_Catalog": + self.Attack_Pattern_Catalog_tag = False + + +class CWEHandler(ContentHandler): + def __init__(self): + self.cwe = [] + self.description_tag = False + self.category_tag = False + self.weakness_tag = False + self.weakness_relationships_tag = False + self.category_relationships_tag = False + + def startElement(self, name, attrs): + + if name == "Weakness": + self.weakness_tag = True + self.statement = "" + self.weaknessabs = attrs.get("Abstraction") + self.name = attrs.get("Name") + self.idname = attrs.get("ID") + self.status = attrs.get("Status") + if not self.name.startswith("DEPRECATED"): + self.cwe.append( + { + "name": self.name, + "id": self.idname, + "status": self.status, + "weaknessabs": self.weaknessabs, + } + ) + + elif name == "Category": + self.category_tag = True + self.category_name = attrs.get("Name") + self.category_id = attrs.get("ID") + self.category_status = attrs.get("Status") + if not self.category_name.startswith("DEPRECATED"): + self.cwe.append( + { + "name": self.category_name, + "id": self.category_id, + "status": self.category_status, + "weaknessabs": "Category", + } + ) + + elif name == "Description" and self.weakness_tag: + self.description_tag = True + self.description = "" + + elif name == "Summary" and self.category_tag: + self.description_tag = True + self.description = "" + + elif name == "Relationships" and self.category_tag: + self.category_relationships_tag = True + self.relationships = [] + + elif name == "Related_Weaknesses" and self.weakness_tag: + self.weakness_relationships_tag = True + self.relationships = [] + + elif name == "Related_Weakness" and self.weakness_relationships_tag: + self.relationships.append(attrs.get("CWE_ID")) + + elif name == "Has_Member" and self.category_relationships_tag: + self.relationships.append(attrs.get("CWE_ID")) + + def characters(self, ch): + if self.description_tag: + self.description += ch.replace(" ", "") + + def endElement(self, name): + if name == "Description" and self.weakness_tag: + self.description_tag = False + self.description = self.description + self.description + self.cwe[-1]["Description"] = self.description.replace("\n", "") + if name == "Summary" and self.category_tag: + self.description_tag = False + self.description = self.description + self.description + self.cwe[-1]["Description"] = self.description.replace("\n", "") + elif name == "Weakness" and self.weakness_tag: + self.weakness_tag = False + elif name == "Category" and self.category_tag: + self.category_tag = False + + elif name == "Related_Weaknesses" and self.weakness_tag: + self.weakness_relationships_tag = False + self.cwe[-1]["related_weaknesses"] = self.relationships + + elif name == "Relationships" and self.category_tag: + self.category_relationships_tag = False + self.cwe[-1]["relationships"] = self.relationships diff --git a/CveXplore/update/db_action.py b/CveXplore/update/db_action.py new file mode 100644 index 00000000..d26fc712 --- /dev/null +++ b/CveXplore/update/db_action.py @@ -0,0 +1,21 @@ +import collections + +from pymongo import InsertOne, UpdateOne + + +class DatabaseAction(object): + + actions = collections.namedtuple("Actions", "InsertOne UpdateOne")(0, 1) + + def __init__(self, action, collection, doc): + + 
self.action = action + self.collection = collection + self.doc = doc + + @property + def entry(self): + if self.action == self.actions.InsertOne: + return InsertOne(self.doc) + elif self.action == self.actions.UpdateOne: + return UpdateOne({"id": self.doc["id"]}, {"$set": self.doc}, upsert=True) diff --git a/CveXplore/update/main_updater.py b/CveXplore/update/main_updater.py new file mode 100644 index 00000000..732bab63 --- /dev/null +++ b/CveXplore/update/main_updater.py @@ -0,0 +1,44 @@ +from CveXplore.update.Sources_process import ( + CPEDownloads, + CVEDownloads, + CWEDownloads, + CAPECDownloads, + VIADownloads, + DatabaseIndexer, +) + + +class MainUpdater(object): + def __init__(self, repopulate=False): + + self.repopulate = repopulate + + self.sources = [ + {"name": "cpe", "updater": CPEDownloads}, + {"name": "cve", "updater": CVEDownloads}, + {"name": "cwe", "updater": CWEDownloads}, + {"name": "capec", "updater": CAPECDownloads}, + {"name": "via4", "updater": VIADownloads}, + ] + + self.posts = [{"name": "ensureindex", "updater": DatabaseIndexer}] + + def update(self): + + for source in self.sources: + up = source["updater"]() + up.update() + + for post in self.posts: + indexer = post["updater"]() + indexer.create_indexes() + + def populate(self): + + for source in self.sources: + populator = source["updater"]() + populator.populate() + + for post in self.posts: + indexer = post["updater"]() + indexer.create_indexes() diff --git a/CveXplore/update/redis_q.py b/CveXplore/update/redis_q.py new file mode 100644 index 00000000..4dd147a1 --- /dev/null +++ b/CveXplore/update/redis_q.py @@ -0,0 +1,359 @@ +import threading +from collections import deque + + +from queue import Empty, Full, Queue +from time import time + +import jsonpickle + +from .Config import Configuration +from .db_action import DatabaseAction + + +class RedisQueue(object): + def __init__(self, name, serializer=jsonpickle, namespace="queue"): + self.__db = Configuration.getRedisQConnection() + self.serializer = serializer + self._key = "{}:{}".format(name, namespace) + + def __len__(self): + return self.qsize() + + def __repr__(self): + return "<< RedisQueue:{} >>".format(self.key) + + def __iter__(self): + return self + + def __next__(self): + item = self.get(timeout=1) + if item is not None: + if isinstance(item, DatabaseAction): + item = item.entry + return item + else: + raise StopIteration + + @property + def key(self): + return self._key + + def get_full_list(self): + + entries = self.__db.lrange(self.key, 0, -1) + + self.__db.delete(self.key) + + return [self.serializer.decode(entry) for entry in entries] + + def clear(self): + """Clear the queue of all messages, deleting the Redis key.""" + self.__db.delete(self.key) + + def qsize(self): + """ + Return size of the queue + + :return: + :rtype: + """ + return self.__db.llen(self.key) + + def get(self, block=False, timeout=None): + """ + Return an item from the queue. 
+ + :param block: Whether or not to wait for item to be available; defaults to False + :type block: bool + :param timeout: Time to wait for item to be available in the queue; defaults to None + :type timeout: int + :return: Item popped from list + :rtype: * + """ + if block: + if timeout is None: + timeout = 0 + item = self.__db.blpop(self.key, timeout=timeout) + if item is not None: + item = item[1] + else: + item = self.__db.lpop(self.key) + if item is not None and self.serializer is not None: + item = self.serializer.decode(item) + return item + + def put(self, *items): + """ + Put one or more items onto the queue. + + Example: + + q.put("my item") + q.put("another item") + + To put messages onto the queue in bulk, which can be significantly + faster if you have a large number of messages: + + q.put("my item", "another item", "third item") + """ + if self.serializer is not None: + items = map(self.serializer.encode, items) + self.__db.rpush(self.key, *items) + + +# class CveXploreQueue(Queue): +# +# def __init__(self, name, maxsize=0, serializer=jsonpickle): +# super().__init__(maxsize) +# self.name = name +# +# self.serializer = serializer +# +# def __repr__(self): +# return "<< CveXploreQueue:{} >>".format(self.name) +# +# # Put a new item in the queue +# def _put(self, item): +# self.queue.append(self.serializer.encode(item)) +# +# # Get an item from the queue +# def _get(self): +# item = self.serializer.decode(self.queue.popleft()) +# if isinstance(item, DatabaseAction): +# item = item.entry +# return item +# +# def getall(self, block=True, timeout=None): +# with self.not_empty: +# if self.closed: +# raise QueueClosed() +# if not block: +# if not self._qsize(): +# raise Empty +# elif timeout is None: +# while not self._qsize() and not self.closed: +# self.not_empty.wait() +# elif timeout < 0: +# raise ValueError("'timeout' must be a non-negative number") +# else: +# endtime = time() + timeout +# while not self._qsize() and not self.closed: +# remaining = endtime - time() +# if remaining <= 0.0: +# raise Empty +# self.not_empty.wait(remaining) +# if self.closed: +# raise QueueClosed() +# items = list(self.queue) +# items = list(map(self.serializer.decode, items)) +# self.queue.clear() +# self.not_full.notify() +# return items +# +# def iter_queue(self): +# with self.not_empty: +# item = self.get(timeout=1) +# if item is not None: +# if isinstance(item, DatabaseAction): +# item = item.entry +# yield item +# else: +# self.not_full.notify() +# raise StopIteration +# +# def clear(self): +# with self.not_empty: +# self.queue.clear() +# self.not_full.notify() + + +class QueueClosed(Exception): + pass + + +class CveXploreQueue(object): + + def __init__(self, name, maxsize=0, serializer=jsonpickle): + self.name = name + self.maxsize = maxsize + self._init(maxsize) + + self.mutex = threading.Lock() + + self.not_empty = threading.Condition(self.mutex) + + self.not_full = threading.Condition(self.mutex) + + self.all_tasks_done = threading.Condition(self.mutex) + self.unfinished_tasks = 0 + self.closed = False + + self.serializer = serializer + + def __len__(self): + self.qsize() + + def __repr__(self): + return "<< CveXploreQueue:{} >>".format(self.name) + + def __iter__(self): + return self + + def __next__(self): + with self.mutex: + item = self.get(timeout=1) + if item is not None: + if isinstance(item, DatabaseAction): + item = item.entry + return item + else: + raise StopIteration + + def task_done(self): + with self.all_tasks_done: + unfinished = self.unfinished_tasks - 1 + if unfinished 
<= 0: + if unfinished < 0: + raise ValueError('task_done() called too many times') + self.all_tasks_done.notify_all() + self.unfinished_tasks = unfinished + + def join(self): + with self.all_tasks_done: + while self.unfinished_tasks and not self.closed: + self.all_tasks_done.wait() + + def qsize(self): + with self.mutex: + return self._qsize() + + def empty(self): + with self.mutex: + return not self._qsize() + + def full(self): + with self.mutex: + return 0 < self.maxsize <= self._qsize() + + def put_nowait(self, item): + return self.put(item, block=False) + + def get_nowait(self): + return self.get(block=False) + + def _init(self, maxsize): + self.queue = deque() + + def _qsize(self): + return len(self.queue) + + def _put(self, item): + self.queue.append(self.serializer.encode(item)) + + def _get(self): + item = self.serializer.decode(self.queue.popleft()) + if isinstance(item, DatabaseAction): + item = item.entry + return item + + def close(self): + with self.mutex: + self.closed = True + self.not_empty.notify_all() + self.not_full.notify_all() + self.all_tasks_done.notify_all() + + def getall(self, block=True, timeout=None): + with self.not_empty: + if self.closed: + raise QueueClosed() + if not block: + if not self._qsize(): + raise Empty + elif timeout is None: + while not self._qsize() and not self.closed: + self.not_empty.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + endtime = time() + timeout + while not self._qsize() and not self.closed: + remaining = endtime - time() + if remaining <= 0.0: + raise Empty + self.not_empty.wait(remaining) + if self.closed: + raise QueueClosed() + items = list(self.queue) + items = list(map(self.serializer.decode, items)) + self.queue.clear() + self.not_full.notify() + return items + + def iter_queue(self): + with self.not_empty: + item = self.get(timeout=1) + if item is not None: + if isinstance(item, DatabaseAction): + item = item.entry + yield item + else: + self.not_full.notify() + raise StopIteration + + def clear(self): + with self.not_empty: + self.queue.clear() + self.not_full.notify() + + def put(self, item, block=True, timeout=None): + with self.not_full: + if self.closed: + raise QueueClosed() + if self.maxsize > 0: + if not block: + if self._qsize() >= self.maxsize: + raise Full + elif timeout is None: + while self._qsize() >= self.maxsize and not self.closed: + self.not_full.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + endtime = time() + timeout + while self._qsize() >= self.maxsize and not self.closed: + remaining = endtime - time() + if remaining <= 0.0: + raise Full + self.not_full.wait(remaining) + if self.closed: + raise QueueClosed() + self._put(item) + self.unfinished_tasks += 1 + self.not_empty.notify() + + def get(self, block=True, timeout=None): + with self.not_empty: + if self.closed: + raise QueueClosed() + if not block: + if not self._qsize(): + raise Empty + elif timeout is None: + while not self._qsize() and not self.closed: + self.not_empty.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + endtime = time() + timeout + while not self._qsize() and not self.closed: + remaining = endtime - time() + if remaining <= 0.0: + raise Empty + self.not_empty.wait(remaining) + if self.closed: + raise QueueClosed() + item = self._get() + self.not_full.notify() + return item
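How the pieces introduced in this changeset fit together, as a minimal sketch (this is not part of the patch; it assumes a Redis instance reachable with the Config defaults, and the example document plus the final bulk_write consumer step are hypothetical, since that side of the flow lives outside these hunks):

# Minimal sketch, not part of the patch: a producer/consumer round trip through
# the new update queue. Assumes Redis is reachable with the Config defaults.
from CveXplore.update.db_action import DatabaseAction
from CveXplore.update.redis_q import RedisQueue

q = RedisQueue(name="cpe")  # stored under the Redis key "cpe:queue"

# Producer side: what XMLFileHandler.process_item does for every parsed item.
q.put(
    DatabaseAction(
        action=DatabaseAction.actions.UpdateOne,
        collection="cpe",
        doc={"id": "example-id", "title": "Example product"},  # hypothetical document
    )
)

# Consumer side: iterating a RedisQueue decodes each jsonpickle payload and, for
# DatabaseAction items, returns the ready-made pymongo request via the `entry`
# property (UpdateOne upserting on "id", or InsertOne), until the list is drained.
requests = [entry for entry in q]

# A worker could then hand these to pymongo in a single round trip, e.g.:
# db["cpe"].bulk_write(requests, ordered=False)  # hypothetical consumer step

Serialising DatabaseAction objects with jsonpickle keeps the queue payloads self-describing: the consumer does not need to know which feed produced a document, it simply recovers the pymongo operation through the entry property.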