From 67753b9b2918da7863aa3c44e70c2872f7ee8946 Mon Sep 17 00:00:00 2001 From: Anil Vaza Date: Fri, 11 Jun 2021 19:32:20 +0530 Subject: [PATCH] Added support in siem.py to communicate with legacy-siem service --- README.md | 9 +- api_client.py | 566 +++++++++++++++++++++++++++++ config.ini | 16 + config.py | 54 +-- name_mapping.py | 112 +----- siem.py | 589 +++++++++++++------------------ state.py | 116 ++++++ test_regression.py | 109 +++--- tests/unit/test_api_client.py | 442 +++++++++++++++++++++++ tests/unit/test_call_endpoint.py | 217 ------------ tests/unit/test_config.py | 48 +++ tests/unit/test_name_mapping.py | 119 +++++++ tests/unit/test_siem.py | 310 +++++++--------- tests/unit/test_state.py | 79 +++++ 14 files changed, 1840 insertions(+), 946 deletions(-) create mode 100644 api_client.py create mode 100644 state.py create mode 100644 tests/unit/test_api_client.py delete mode 100644 tests/unit/test_call_endpoint.py create mode 100644 tests/unit/test_config.py create mode 100644 tests/unit/test_name_mapping.py create mode 100644 tests/unit/test_state.py diff --git a/README.md b/README.md index 3633119..14af592 100755 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Any issue discovered using the script should be reported to Sophos Support. The script in this directory allows you to use the Sophos Central API to get data into your SIEM solution. -Access to the APIs requires an access token that can be setup in the Sophos Central UI by going to System Settings from the navigation bar and then selecting API Token Management. From this page, you can click the Add Token button to create a new token. +Access to the APIs requires an access token or API Credentials that can be setup in the Sophos Central UI by going to System Settings from the navigation bar and then selecting API Token Management or API Credentials. From this page, you can click the Add Token button to create a new token. 
Here is more information available on how to setup API Token: https://community.sophos.com/kb/en-us/125169 You can view API Swagger Specification by accessing API Access URL from the access token created under Api Token Management in Sophos Central UI. @@ -24,7 +24,7 @@ You can view API Swagger Specification by accessing API Access URL from the acce Download and extract from [here](https://github.com/sophos/Sophos-Central-SIEM-Integration/archive/v1.1.0.zip) for the latest release. For older version, please consult the Releases section below. For changes to the API, please consult the API Updates section below. -The script requires Python 2.7.9+ to run. +The script requires Python 3.6+ to run. #### Releases #### @@ -58,11 +58,14 @@ config.ini is a configuration file that exists by default in the siem-scripts fo ##### Here are the steps to configure the script: 1. Open config.ini in a text editor. 2. Under 'API Access URL + Headers' in the config file, copy and paste the API Access URL + Headers block from the Api Token Management page in Sophos Central. +3. Under Client ID and Client Secret in the config file, copy and paste the API Credentials from the API Token Management page in Sophos Central. +4. Under Customer tenant id in the config file, you can mention the tenant id for which you want to fetch alerts and events. ##### Optional configuration steps: 1. Under json, cef or keyvalue, you could choose the preferred output of the response i.e. json, cef or keyvalue. 2. Under filename, you can specify the filename that your output would be saved to. Options are syslog, stdout or any custom file name. Custom files are created in a folder named log. 3. If you are using syslog then under syslog properties in the config file, configure address, facility and socktype. +4. 
Under state_file_path, specify the full or relative path to the cache file (with a ".json" extension). ### Running the script @@ -73,7 +76,7 @@ For more options and help on running the script run 'python siem.py -h' ### License -Copyright 2016 Sophos Limited +Copyright 2016-2021 Sophos Limited Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 diff --git a/api_client.py b/api_client.py new file mode 100644 index 0000000..1617193 --- /dev/null +++ b/api_client.py @@ -0,0 +1,566 @@ +#!/usr/bin/env python + +# Copyright 2019-2021 Sophos Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. +# You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing permissions and limitations under the +# License. 
+# +import sys +import calendar + +import urllib.request as urlrequest +import urllib.error as urlerror +from urllib.parse import urlencode + +import datetime +import json +import logging +import logging.handlers +import os +import socket +import name_mapping +from random import randint +import time +import config + + +SYSLOG_SOCKTYPE = {"udp": socket.SOCK_DGRAM, "tcp": socket.SOCK_STREAM} + +# Initialize the SIEM_LOGGER +SIEM_LOGGER = logging.getLogger("SIEM") +SIEM_LOGGER.setLevel(logging.INFO) +SIEM_LOGGER.propagate = False +logging.basicConfig(format="%(message)s") + +EVENTS_V1 = "/siem/v1/events" +ALERTS_V1 = "/siem/v1/alerts" + +EVENT_TYPE = "event" +ALERT_TYPE = "alert" + +ENDPOINT_MAP = { + "event": [EVENTS_V1], + "alert": [ALERTS_V1], + "all": [EVENTS_V1, ALERTS_V1], +} + +# Initialize the SIEM_LOGGER +SIEM_LOGGER = logging.getLogger("SIEM") + + +class ApiClient: + def __init__(self, endpoint, options, config, state): + """Class providing alerts and events data""" + + self.state = state + self.state_data = state.state_data + self.endpoint = endpoint + self.options = options + self.config = config + logdir = self.create_log_dir() + self.add_siem_logeer_handler(logdir) + self.opener = self.create_request_builder() + self.get_noisy_event_types = self.get_noisy_event_types() + + def log(self, log_message): + """Write the log. 
+ Arguments: + log_message {string} -- log content + """ + if not self.options.quiet: + sys.stderr.write("%s\n" % log_message) + + def get_noisy_event_types(self): + """Return noisy event types + Returns: + list -- noisy event type list + """ + return [k for k, v in name_mapping.TYPE_HANDLERS.items() if not v] + + def create_request_builder(self): + """Create the request build + Returns: + dict -- request builder + """ + if self.options.debug: + handler = urlrequest.HTTPSHandler(debuglevel=1) + else: + handler = urlrequest.HTTPSHandler() + + return urlrequest.build_opener(handler) + + def create_log_dir(self): + """Create the log directory + Returns: + log_dir {string} -- log directory path + """ + if "SOPHOS_SIEM_HOME" in os.environ: + app_path = os.environ["SOPHOS_SIEM_HOME"] + else: + app_path = os.path.join(os.getcwd()) + + log_dir = os.path.join(app_path, "log") + if not os.path.exists(log_dir): + try: + os.makedirs(log_dir) + return log_dir + except OSError as e: + self.log("Failed to create %s, %s" % (log_dir, str(e))) + sys.exit(1) + return log_dir + + def get_syslog_facilities(self): + """Create a mapping between our names and the python syslog defines + Returns: + out {dict} -- syslog facilities + """ + out = {} + possible = ( + "auth cron daemon kern lpr mail news syslog user uucp " + "local0 local1 local2 local3 local4 local5 local6 local7".split() + ) + for facility in possible: + out[facility] = getattr( + logging.handlers.SysLogHandler, "LOG_%s" % facility.upper() + ) + return out + + def jitter(self): + """ Added the rendom sleep """ + time.sleep(randint(0, 10)) + + def add_siem_logeer_handler(self, logdir): + """Added the log handler + Arguments: + logdir {string}: log directory path + """ + if self.config.filename == "syslog": + syslog_facility = self.get_syslog_facilities() + facility = syslog_facility[self.config.facility] + address = self.config.address + if ":" in address: + result = address.split(":") + host = result[0] + port = result[1] + 
address = (host, int(port)) + + socktype = SYSLOG_SOCKTYPE[self.config.socktype] + logging_handler = logging.handlers.SysLogHandler( + address, facility, socktype + ) + elif self.config.filename == "stdout": + logging_handler = logging.StreamHandler(sys.stdout) + else: + logging_handler = logging.FileHandler( + os.path.join(logdir, self.config.filename), "a", encoding="utf-8" + ) + SIEM_LOGGER.addHandler(logging_handler) + + def get_past_datetime(self, hours): + """Get the past datetime based on hours argument + Arguments: + hours {string}: number + Returns: + string -- return past datetime + """ + return int( + calendar.timegm( + ( + ( + datetime.datetime.utcnow() - datetime.timedelta(hours=hours) + ).timetuple() + ) + ) + ) + + def request_url(self, host_url, body, header, retry_count=3): + """Make the request and return response data or throw exception + Arguments: + host_url {string}: req url + body {dict}: req body + header {dict}: req header + retry_count {number}: retry request count + Returns: + response -- response data or throw exception + """ + for i in range(0, retry_count): + try: + data = urlencode(body).encode("utf-8") if body is not None else body + request = urlrequest.Request(host_url, data, header) + response = self.opener.open(request) + except urlerror.HTTPError as e: + if e.code in (503, 504, 403, 429): + self.log( + 'Error "%s" (code %s) on attempt #%s of %s, retrying' + % (e, e.code, i, retry_count) + ) + if i < retry_count: + continue + self.log( + "Error during request. 
Error code: %s, Error message: %s" + % (e.code, e.read()) + ) + raise + return response.read() + + def get_alerts_or_events(self): + """Get alerts/events data + Returns: + results {list} -- alerts/events response data + """ + endpoint_name = self.endpoint.rsplit("/", 1)[-1] + + if self.options.light and self.endpoint == ENDPOINT_MAP["event"][0]: + self.log( + "Light mode - not retrieving:%s" % "; ".join(self.get_noisy_event_types) + ) + + self.log( + "Config endpoint=%s, filename='%s' and format='%s'" + % (self.endpoint, self.config.filename, self.config.format) + ) + + since = False + if self.options.since: + since = self.options.since + self.log("Retrieving results since: %s" % since) + else: + self.log("No datetime found, defaulting to last 12 hours for results") + since = self.get_past_datetime(12) + + if ( + self.config.client_id + and self.config.client_secret + ): + tenant_obj = self.get_tenants_from_sophos() + + if "items" in tenant_obj: + results = self.make_credentials_request( + since, endpoint_name, tenant_obj + ) + else: + self.log("Error :: %s" % tenant_obj["error"]) + raise Exception(tenant_obj['error']) + else: + token_data = config.Token(self.config.token_info) + results = self.make_token_request( + since, endpoint_name, token_data + ) + return results + + def call_endpoint(self, api_host, default_headers, args): + """Execute the API request + Arguments: + api_host {string}: host name + default_headers {object}: request header + args {string}: request argument + Returns: + events {list} -- API response + """ + events_request_url = "%s%s?%s" % (api_host, self.endpoint, args) + self.log("URL: %s" % events_request_url) + events_response = self.request_url(events_request_url, None, default_headers) + if self.options.debug: + self.log("RESPONSE: %s" % events_response) + events = json.loads(events_response) + return events + + def get_alerts_or_events_req_args(self, params): + """Convert the params to query string + Arguments: + params {dict}: params 
object + Returns: + args {string} -- arguments string + """ + if self.options.light and self.endpoint == ENDPOINT_MAP["event"][0]: + types = ",".join(["%s" % t for t in self.get_noisy_event_types]) + types = "exclude_types=" + types + args = "&".join( + ["%s=%s" % (k, v) for k, v in params.items()] + + [ + types, + ] + ) + else: + args = "&".join(["%s=%s" % (k, v) for k, v in params.items()]) + return args + + def make_token_request(self, since, endpoint_name, token): + """Make alerts/events request by using token info. + Arguments: + since {number}: Return results from specified time + endpoint_name {string}: endpoint name + token {string} -- token + Returns: + dict -- yield event/alert object + """ + state_data_key = endpoint_name + "LastFetched" + default_headers = { + "Content-Type": "application/json; charset=utf-8", + "Accept": "application/json", + "X-Locale": "en", + "Authorization": token.authorization, + "x-api-key": token.api_key, + } + token_val = token.authorization.split()[1] + + params = {"limit": 1000} + + if ( + "account" in self.state_data + and token_val in self.state_data["account"] + and state_data_key in self.state_data["account"][token_val] + ): + params["cursor"] = self.state_data["account"][token_val][state_data_key] + self.jitter() + else: + params["from_date"] = since + + args = self.get_alerts_or_events_req_args(params) + + while True: + events = self.call_endpoint(token.url, default_headers, args) + + if "items" in events and len(events["items"]) > 0: + for e in events["items"]: + e["datastream"] = EVENT_TYPE if (self.endpoint == EVENTS_V1) else ALERT_TYPE + yield e + else: + self.log( + "No new %s data retrieved from the API" + % endpoint_name + ) + data_key = "account." + token_val + "." 
+ state_data_key + self.state.save_state(data_key, events["next_cursor"]) + if not events["has_more"]: + break + else: + params["cursor"] = events["next_cursor"] + params.pop("from_date", None) + + def make_credentials_request(self, since, endpoint_name, tenant_obj): + """Make alerts/events request by using API credentials. + Arguments: + since {number}: Return results from specified time + endpoint_name {string}: endpoint name + tenant_obj {object} -- tenant object + Returns: + dict -- yield event/alert object + """ + state_data_key = endpoint_name + "LastFetched" + for tenant in tenant_obj["items"]: + tenant_id = tenant["id"] + default_headers = { + "X-Tenant-ID": tenant_id, + "Authorization": "Bearer " + tenant_obj["access_token"], + } + params = {"limit": 1000} + + if ( + "tenants" in self.state_data + and tenant_id in self.state_data["tenants"] + and state_data_key in self.state_data["tenants"][tenant_id] + ): + params["cursor"] = self.state_data["tenants"][tenant_id][state_data_key] + self.jitter() + else: + params["from_date"] = since + + args = self.get_alerts_or_events_req_args(params) + + while True: + dataRegionURL = tenant["apiHost"] if 'idType' not in tenant else tenant['apiHosts']['dataRegion'] + events = self.call_endpoint(dataRegionURL, default_headers, args) + + if "items" in events and len(events["items"]) > 0: + for e in events["items"]: + e["datastream"] = ( + EVENT_TYPE if (self.endpoint == EVENTS_V1) else ALERT_TYPE + ) + yield e + else: + self.log( + "No new %s data retrieved from the API" + % endpoint_name + ) + cursor_key = "tenants." + tenant_id + "." + state_data_key + data_region_url_key = "tenants." + tenant_id + ".dataRegionUrl" + last_run_key = "tenants." 
+ tenant_id + ".lastRunAt" + + self.state.save_state(cursor_key, events["next_cursor"]) + self.state.save_state(data_region_url_key, dataRegionURL) + self.state.save_state(last_run_key, time.time()) + if not events["has_more"]: + break + else: + params["cursor"] = events["next_cursor"] + params.pop("from_date", None) + + def get_tenants_from_sophos(self): + """Fetch the tenants for partner or organization. + Get the tenants by calling Sophos tenant API. + Returns: + dict -- response containing either list of tenant or error + """ + self.log("Fetching the tenants/customers list by calling the Sophos Cental API") + response = self.get_sophos_jwt() + + if "access_token" in response: + access_token = response["access_token"] + whoami_response = self.get_whoami_data(access_token) + if "id" in whoami_response: + if ( + whoami_response["idType"] == "partner" + or whoami_response["idType"] == "organization" + ): + tenant_data = self.get_partner_organization_tenants( + whoami_response, access_token + ) + tenant_data["access_token"] = access_token + else: + if ( + self.config.tenant_id != "" + and self.config.tenant_id + != whoami_response["id"] + ): + raise Exception( + "Configuration file mention tenant id not matched with whoami data tenant id" + ) + else: + tenant_data = {"items": [whoami_response]} + tenant_data["access_token"] = access_token + return tenant_data + else: + self.log( + "Whoami data not found for client id :: %s" + % self.config.client_id + ) + return whoami_response + else: + self.log( + "JWT token not found for client id :: %s" + % self.config.client_id + ) + return response + + def get_sophos_jwt(self): + """Fetch the Sophos JWT access token. + Get the token by calling Sophos API. 
+ Returns: + dict -- response containing either of jwt token or error + """ + self.log("fetching access_token from sophos") + client_id = self.config.client_id + client_secret = self.config.client_secret + body = { + "grant_type": "client_credentials", + "scope": "token", + "client_id": client_id, + "client_secret": client_secret, + } + self.log("body :: %s" % str(body)) + current_time = time.time() + cache_client_data = ( + self.state_data["account"][client_id] + if "account" in self.state_data and client_id in self.state_data["account"] + else "" + ) + + if cache_client_data and current_time < cache_client_data["jwtExpiresAt"]: + self.log("return token from cache :: %s" % cache_client_data["jwt"]) + return {"access_token": cache_client_data["jwt"]} + else: + try: + response = self.request_url( + self.config.auth_url, body, {}, retry_count=3 + ) + response_data = json.loads(response) + + self.state.save_state( + "account.%s.jwt" % client_id, response_data["access_token"] + ) + self.state.save_state( + "account.%s.jwtExpiresAt" % client_id, + time.time() + (response_data["expires_in"] - 120), + ) + self.log("response :: %s" % str(response_data)) + return response_data + except json.decoder.JSONDecodeError as e: + self.log("Sophos Token API response not in valid json format") + return {"error": e} + except Exception as e: + self.log("Error :: %s" % e) + return {"error": e} + + def get_whoami_data(self, access_token): + """Fetch the Whoami data. + Get the customer/partner/organization data by calling Whoami API. 
+ Arguments: + access_token {string}' JWT token value (default: {None}) + Returns: + dict -- response containing whoami response or error + """ + self.log("fetching whoami data") + try: + whoami_url = f"https://{self.config.api_host}/whoami/v1" + default_headers = {"Authorization": "Bearer " + access_token} + whoami_response = self.request_url(whoami_url, None, default_headers) + + self.log("Whoami response: %s" % whoami_response) + whoami_data = json.loads(whoami_response) + self.state.save_state( + "account.%s.whoami" % self.config.client_id, whoami_data + ) + return whoami_data + except json.decoder.JSONDecodeError as e: + self.log("Sophos whoami API response not in json format") + return {"error": e} + except Exception as e: + self.log("Error :: %s" % e) + return {"error": e} + + def get_partner_organization_tenants(self, whoami_response, access_token): + """Get the tenants for partner and organization by calling tenant API. + Arguments: + whoami_response {object}: whoami data + access_token {string}' JWT token value (default: {None}) + Returns: + dict -- response containing whoami response or error + """ + tenants = {} + try: + default_headers = { + "Authorization": "Bearer " + access_token, + "X-Partner-ID": whoami_response["id"], + } + tenant_url = ( + whoami_response["apiHosts"]["global"] + + "/" + + whoami_response["idType"] + + "/v1/tenants" + ) + tenant_response = self.request_url(tenant_url, None, default_headers, 1) + + self.log("Tenant response: %s" % tenant_response) + tenants = json.loads(tenant_response) + tenants["items"] = list( + filter(lambda tenant: tenant["id"] == self.config.tenant_id, tenants["items"]) + ) + if len(tenants["items"]) > 0: + return tenants + else: + raise Exception( + "Configuration file mention tenant id not matched with whoami data tenant id" + ) + except json.decoder.JSONDecodeError as e: + self.log("Sophos partner tenant API response not in json format") + return {"error": e} + except Exception as e: + self.log("Error :: 
%s" % e) + return {"error": e} diff --git a/config.ini b/config.ini index d0f91dd..41dacc3 100755 --- a/config.ini +++ b/config.ini @@ -4,6 +4,19 @@ token_info = +# Client ID and Client Secret for partner +# +client_id = +client_secret = +# Customer tenant Id +tenant_id = + +# Host URL for Oauth token +auth_url = + +# whoami API host url +api_host = + # format can be json, cef or keyvalue format = json @@ -20,3 +33,6 @@ endpoint = event address = /var/run/syslog facility = daemon socktype = udp + +# cache file full or relative path (with a ".json" extension) +state_file_path = state/siem_sophos.json \ No newline at end of file diff --git a/config.py b/config.py index d39364d..37eba09 100644 --- a/config.py +++ b/config.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2019 Sophos Limited +# Copyright 2019-2021 Sophos Limited # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in # compliance with the License. @@ -12,26 +12,20 @@ # License. 
# -import unittest -import shutil -import tempfile -import os import re -try: - import ConfigParser -except ImportError: - import configparser as ConfigParser +import configparser as ConfigParser class Config: """Class providing config values""" + def __init__(self, path): """Open the config file""" self.config = ConfigParser.ConfigParser() self.config.read(path) - + def __getattr__(self, name): - return self.config.get('login', name) + return self.config.get("login", name) class Token: @@ -43,41 +37,3 @@ def __init__(self, token_txt): self.url = m.group("url") self.api_key = m.group("api_key") self.authorization = m.group("authorization").strip() - - -# -# TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE -# TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE -# - - -class TestConfig(unittest.TestCase): - """Test Config file items are exposed as attributes on config object""" - - def setUp(self): - self.tmpdir = tempfile.mkdtemp(prefix="config_test", dir=".") - - def tearDown(self): - if os.path.exists(self.tmpdir): - shutil.rmtree(self.tmpdir) - - def testReadingWhenAttributeExists(self): - cfg_path = os.path.join(self.tmpdir, "config.ini") - with open(cfg_path, "wb") as fp: - fp.write("[login]\ntoken_info = MY_TOKEN\n".encode("utf-8")) - cfg = Config(cfg_path) - self.assertEqual(cfg.token_info, "MY_TOKEN") - - -class TestToken(unittest.TestCase): - """Test the token gets parsed""" - def testParse(self): - txt = " url: https://anywhere.com/api, x-api-key: random, Authorization: Basic KJNKLJNjklNLKHB= " - t = Token(txt) - self.assertEqual(t.url, "https://anywhere.com/api") - self.assertEqual(t.api_key, "random") - self.assertEqual(t.authorization, "Basic KJNKLJNjklNLKHB=") - - -if __name__ == '__main__': - unittest.main() diff --git a/name_mapping.py b/name_mapping.py index fef9a4f..0e5045c 100644 --- a/name_mapping.py +++ b/name_mapping.py @@ -1,4 +1,4 @@ -# 
Copyright 2019 Sophos Limited +# Copyright 2019-2021 Sophos Limited # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in # compliance with the License. @@ -11,13 +11,11 @@ # import re -import unittest -import copy threat_regex = re.compile("'(?P.*?)'.+'(?P.*?)'") -# What to do with the different types of event. None indicates drop the event, otherwise a regex extracts the +# What to do with the different types of event. None indicates drop the event, otherwise a regex extracts the # various fields and inserts them into the event dictionary. TYPE_HANDLERS = { "Event::Endpoint::Threat::Detected": threat_regex, @@ -60,7 +58,7 @@ def update_fields(log, data): if data[u'type'] in TYPE_HANDLERS: prog_regex = TYPE_HANDLERS[data[u'type']] - if not prog_regex: + if not prog_regex: return result = prog_regex.search(data[u'name']) if not result: @@ -74,107 +72,3 @@ def update_fields(log, data): # Update the record with the split out parameters data.update(result.groupdict()) - - -# -# TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE -# TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE TEST CODE -# - - -def contains(dict_outer, dict_inner): - return all(item in dict_outer.items() for item in dict_inner.items()) - - -class TestNameExtraction(unittest.TestCase): - """Test logging output""" - - def setUp(self): - self.output = [] - - def tearDown(self): - pass - - def log(self, s): - self.output.append(s) - - def testUpdateNameDLPValid(self): - """DLP event with data that can be extracted""" - data = { - "type": "Event::Endpoint::DataLossPreventionUserAllowed", - "name": u"An \u2033allow transfer on acceptance by user\u2033 action was taken. 
" - u"Username: WIN10CLOUD2\\Sophos Rule names: \u2032test\u2032 User action: File open " - u"Application Name: Google Chrome Data Control action: Allow " - u"File type: Plain text (ASCII/UTF-8) File size: 36 " - u"Source path: C:\\Users\\Sophos\\Desktop\\test.txt" - } - expected = { - "type": "Event::Endpoint::DataLossPreventionUserAllowed", - "name": "allow transfer on acceptance by user", - "user": "WIN10CLOUD2\\Sophos", - "rule": "test", - "user_action": "File open", - "app_name": "Google Chrome", - "action": "Allow", - "file_type": "Plain text (ASCII/UTF-8)", - "file_size": "36", - "file_path": "C:\\Users\\Sophos\\Desktop\\test.txt" - } - update_fields(self.log, data) - self.assertTrue(all(item in data.items() for item in expected.items())) - self.assertEqual(len(self.output), 0) - - def testUpdateNameThreatValid(self): - """Threat event with data that can be extracted""" - data = {"type": "Event::Endpoint::Threat::CleanedUp", "name": u"Threat 'EICAR' in 'myfile.com' "} - expected = { - "type": "Event::Endpoint::Threat::CleanedUp", - "name": u"EICAR", - "filePath": "myfile.com", - "detection_identity_name": "EICAR" - } - update_fields(self.log, data) - self.assertTrue(contains(data, expected)) # expected data present - self.assertEqual(len(self.output), 0) # no error - - def testUpdateNameInvalid(self): - """A known type, but information can't be extracted (regex mismatch)""" - data = {"type": "Event::Endpoint::DataLossPreventionUserAllowed", "name": u"XXXX Garbage data XXXX"} - before = copy.copy(data) - update_fields(self.log, data) - self.assertEqual(len(self.output), 1) # a line of error output, when the function bails. - self.assertEqual(data, before) # ... 
and data remains unchanged - - def testUpdateNameFromDescription(self): - """Ensure the name gets updated from the description, if present""" - data = {"type": "", "description": "XXX"} - expected = copy.copy(data) - expected["name"] = "XXX" - update_fields(self.log, data) - self.assertEqual(data, expected) - - def testInvalidType(self): - """Ensure that nothing gets changed when the type isn't recognised""" - data = {"type": "<>", "name": "some name"} - expected = copy.copy(data) - update_fields(self.log, data) - self.assertEqual(len(self.output), 0) # not considered an error - self.assertEqual(data, expected) - - def testSkippedType(self): - """Ensure that entry is skipped if it's to be ignored.""" - # First find an event type that is set to 'None' - toskip = None - for k, v in TYPE_HANDLERS.items(): - if not v: - toskip = k - break - data = {"type": toskip, "name": "some name"} - expected = copy.copy(data) - update_fields(self.log, data) - self.assertEqual(len(self.output), 0) # not considered an error - self.assertEqual(data, expected) - - -if __name__ == '__main__': - unittest.main() diff --git a/siem.py b/siem.py index 4f960b0..15e3a91 100755 --- a/siem.py +++ b/siem.py @@ -1,9 +1,9 @@ #!/usr/bin/env python -# Copyright 2019 Sophos Limited +# Copyright 2019-2021 Sophos Limited # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in -# compliance with the License. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. # You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License is @@ -12,91 +12,36 @@ # License. 
# import sys -import calendar - - -try: - # Python 2 - import urllib2 as urlrequest - import urllib2 as urlerror -except ImportError: - # Python 3 - import urllib.request as urlrequest - import urllib.error as urlerror - - -import datetime import json import logging import logging.handlers import os -import pickle import re -import socket -import time - +import state from optparse import OptionParser -from random import randint import name_mapping import config +import api_client - -def get_syslog_facilities(): - """Create a mapping between our names and the python syslog defines""" - out = {} - possible = "auth cron daemon kern lpr mail news syslog user uucp " \ - "local0 local1 local2 local3 local4 local5 local6 local7".split() - for facility in possible: - out[facility] = getattr(logging.handlers.SysLogHandler, "LOG_%s" % facility.upper()) - return out - - -SYSLOG_FACILITY = get_syslog_facilities() - - -SYSLOG_SOCKTYPE = {'udp': socket.SOCK_DGRAM, - 'tcp': socket.SOCK_STREAM - } - - -VERSION = '1.0.0' -LIGHT = False -DEBUG = False +VERSION = "1.0.0" QUIET = False -MISSING_VALUE = 'NA' -DEFAULT_ENDPOINT = 'event' -PREFIX_PATTERN = re.compile(r'([|\\])') -EXTENSION_PATTERN = re.compile(r'([=\\])') +MISSING_VALUE = "NA" +DEFAULT_ENDPOINT = "event" -SEVERITY_MAP = {'none': 0, - 'low': 1, - 'medium': 5, - 'high': 8, - 'very_high': 10} +SEVERITY_MAP = {"none": 0, "low": 1, "medium": 5, "high": 8, "very_high": 10} - -def get_noisy_event_types(): - return [k for k, v in name_mapping.TYPE_HANDLERS.items() if not v] - - -NOISY_EVENTTYPES = get_noisy_event_types() - -EVENTS_V1 = '/siem/v1/events' -ALERTS_V1 = '/siem/v1/alerts' - -EVENT_TYPE = 'event' -ALERT_TYPE = 'alert' - -ENDPOINT_MAP = {'event': [EVENTS_V1], - 'alert': [ALERTS_V1], - 'all': [EVENTS_V1, ALERTS_V1]} - -CEF_CONFIG = {'cef.version': '0', 'cef.device_vendor': 'sophos', - 'cef.device_product': 'sophos central', 'cef.device_version': 1.0} +CEF_CONFIG = { + "cef.version": "0", + "cef.device_vendor": "sophos", 
+ "cef.device_product": "sophos central", + "cef.device_version": 1.0, +} # CEF format from https://www.protect724.hpe.com/docs/DOC-1072 -CEF_FORMAT = ('CEF:%(version)s|%(device_vendor)s|%(device_product)s|' - '%(device_version)s|%(device_event_class_id)s|%(name)s|%(severity)s|') +CEF_FORMAT = ( + "CEF:%(version)s|%(device_vendor)s|%(device_product)s|" + "%(device_version)s|%(device_event_class_id)s|%(name)s|%(severity)s|" +) CEF_MAPPING = { @@ -107,7 +52,6 @@ def get_noisy_event_types(): "device_event_class_id": "type", "name": "name", "severity": "severity", - # json to CEF extension mapping # Format # JSON_key: CEF_extension @@ -120,255 +64,56 @@ def get_noisy_event_types(): } # Initialize the SIEM_LOGGER -SIEM_LOGGER = logging.getLogger('SIEM') +SIEM_LOGGER = logging.getLogger("SIEM") SIEM_LOGGER.setLevel(logging.INFO) SIEM_LOGGER.propagate = False -logging.basicConfig(format='%(message)s') - - -def main(): - global LIGHT, DEBUG, QUIET - - if 'SOPHOS_SIEM_HOME' in os.environ: - app_path = os.environ['SOPHOS_SIEM_HOME'] - else: - # Setup path - app_path = os.path.join(os.getcwd()) - - config_file = os.path.join(app_path, 'config.ini') - - parser = OptionParser(description="Download event and/or alert data and output to various formats. " - "config.ini is a configuration file that exists by default in the siem-scripts " - "folder." - "Script keeps tab of its state, it will always pick-up from where it left-off " - "based on a state file stored in state folder. Set SOPHOS_SIEM_HOME environment " - "variable to point to the folder where config.ini, mapping files, state " - "and log folders will be located. state and log folders are created when the " - "script is run for the first time. 
") - parser.add_option('-s', '--since', default=False, action='store', help="Return results since specified Unix " - "Timestamp, max last 24 hours, defaults to " - "last 12 hours if there is no state file") - parser.add_option('-c', '--config', default=config_file, action='store', help="Specify a configuration file, " - "defaults to config.ini") - parser.add_option('-l', '--light', default=False, action='store_true', help="Ignore noisy events - web control, " - "device control, update failure, " - "application allowed, (non)compliant") - parser.add_option('-d', '--debug', default=False, action='store_true', help="Print debug logs") - parser.add_option('-v', '--version', default=False, action='store_true', help="Print version") - parser.add_option('-q', '--quiet', default=False, action='store_true', help="Suppress status messages") - - options, args = parser.parse_args() - - if options.config is None: - parser.error("Need a config file specified") - - if options.version: - log(VERSION) - sys.exit(0) - if options.quiet: - QUIET = True - - # Read config file - cfg = config.Config(options.config) - token = config.Token(cfg.token_info) - - log("Config loaded, retrieving results for '%s'" % token.api_key) - log("Config retrieving results for '%s'" % token.authorization) - - if cfg.endpoint in ENDPOINT_MAP: - tuple_endpoint = ENDPOINT_MAP[cfg.endpoint] - else: - tuple_endpoint = ENDPOINT_MAP[DEFAULT_ENDPOINT] - - state_dir = os.path.join(app_path, 'state') - log_dir = os.path.join(app_path, 'log') - - create_log_and_state_dir(state_dir, log_dir) - - if options.light: - LIGHT = True - - if options.debug: - DEBUG = True - handler = urlrequest.HTTPSHandler(debuglevel=1) - else: - handler = urlrequest.HTTPSHandler() - opener = urlrequest.build_opener(handler) - - endpoint_config = {'format': cfg.format, - 'filename': cfg.filename, - 'state_dir': state_dir, - 'log_dir': log_dir, - 'since': options.since} - - if cfg.filename == 'syslog': - endpoint_config['facility'] = 
cfg.facility.strip() - endpoint_config['address'] = cfg.address.strip() - endpoint_config['socktype'] = cfg.socktype.strip() - - SIEM_LOGGER.addHandler(create_output_handler(endpoint_config)) - - for endpoint in tuple_endpoint: - process_endpoint(endpoint, opener, endpoint_config, token) - - -def process_endpoint(endpoint, opener, endpoint_config, token): - state_file_name = "siem_lastrun_" + endpoint.rsplit('/', 1)[-1] + ".obj" - state_file_path = os.path.join(endpoint_config['state_dir'], state_file_name) - if LIGHT and endpoint == ENDPOINT_MAP['event'][0]: - log("Light mode - not retrieving:%s" % '; '.join(NOISY_EVENTTYPES)) - - log("Config endpoint=%s, filename='%s' and format='%s'" % - (endpoint, endpoint_config['filename'], endpoint_config['format'])) - log("Config state_file='%s' and cwd='%s'" % (state_file_path, os.getcwd())) - cursor = False - since = False - if endpoint_config['since']: # Run since supplied datetime - since = endpoint_config['since'] - else: - try: # Run since last run (retrieve from state_file) - with open(state_file_path, 'rb') as f: - cursor = pickle.load(f) - except IOError: # Default to current time - since = int(calendar.timegm(((datetime.datetime.utcnow() - datetime.timedelta(hours=12)).timetuple()))) - log("No datetime found, defaulting to last 12 hours for results") - - if since is not False: - log('Retrieving results since: %s' % since) - else: - log('Retrieving results starting cursor: %s' % cursor) - - results = call_endpoint(opener, endpoint, since, cursor, state_file_path, token) - - if endpoint_config['format'] == 'json': - write_json_format(results) - elif endpoint_config['format'] == 'keyvalue': - write_keyvalue_format(results) - elif endpoint_config['format'] == 'cef': - write_cef_format(results) - else: - write_json_format(results) - - -def create_output_handler(endpoint_config): - if endpoint_config['filename'] == 'syslog': - facility = SYSLOG_FACILITY[endpoint_config['facility']] - address = endpoint_config['address'] 
- if ':' in address: - result = address.split(':') - host = result[0] - port = result[1] - address = (host, int(port)) - - socktype = SYSLOG_SOCKTYPE[endpoint_config['socktype']] - logging_handler = logging.handlers.SysLogHandler(address, facility, socktype) - elif endpoint_config['filename'] == 'stdout': - logging_handler = logging.StreamHandler(sys.stdout) - else: - logging_handler = logging. \ - FileHandler(os.path.join(endpoint_config['log_dir'], endpoint_config['filename']), 'a', encoding='utf-8') - return logging_handler +logging.basicConfig(format="%(message)s") def write_json_format(results): + """Write JSON format data. + Arguments: + results {list}: data + """ for i in results: i = remove_null_values(i) update_cef_keys(i) name_mapping.update_fields(log, i) - SIEM_LOGGER.info(json.dumps(i, ensure_ascii=False) + u'\n') + SIEM_LOGGER.info(json.dumps(i, ensure_ascii=False, indent=4) + u"\n") def write_keyvalue_format(results): + """Write key value format data. + Arguments: + results {dict}: results + """ for i in results: i = remove_null_values(i) update_cef_keys(i) name_mapping.update_fields(log, i) - date = i[u'rt'] + date = i[u"rt"] # TODO: Spaces/quotes/semicolons are not escaped here, does it matter? events = list('%s="%s";' % (k, v) for k, v in i.items()) - SIEM_LOGGER.info(' '.join([date, ] + events) + u'\n') + SIEM_LOGGER.info( + " ".join( + [ + date, + ] + + events + ) + + u"\n" + ) def write_cef_format(results): + """Write CEF format data. 
+ Arguments: + results {list}: data + """ for i in results: i = remove_null_values(i) name_mapping.update_fields(log, i) - SIEM_LOGGER.info(format_cef(flatten_json(i)) + u'\n') - - -def create_log_and_state_dir(state_dir, log_dir): - if not os.path.exists(state_dir): - try: - os.makedirs(state_dir) - except OSError as e: - log("Failed to create %s, %s" % (state_dir, str(e))) - sys.exit(1) - if not os.path.exists(log_dir): - try: - os.makedirs(log_dir) - except OSError as e: - log("Failed to create %s, %s" % (log_dir, str(e))) - sys.exit(1) - - -def call_endpoint(opener, endpoint, since, cursor, state_file_path, token): - default_headers = {'Content-Type': 'application/json; charset=utf-8', - 'Accept': 'application/json', - 'X-Locale': 'en', - 'Authorization': token.authorization, - 'x-api-key': token.api_key} - - params = { - 'limit': 1000 - } - if not cursor: - params['from_date'] = since - else: - params['cursor'] = cursor - jitter() - - while True: - if LIGHT and endpoint == ENDPOINT_MAP['event'][0]: - types = ','.join(["%s" % t for t in NOISY_EVENTTYPES]) - types = 'exclude_types=' + types - args = '&'.join(['%s=%s' % (k, v) for k, v in params.items()]+[types, ]) - else: - args = '&'.join(['%s=%s' % (k, v) for k, v in params.items()]) - events_request_url = '%s%s?%s' % (token.url, endpoint, args) - log("URL: %s" % events_request_url) - events_request = urlrequest.Request(events_request_url, None, default_headers) - - for k, v in default_headers.items(): - events_request.add_header(k, v) - - events_response = request_url(opener, events_request) - if DEBUG: - log("RESPONSE: %s" % events_response) - events = json.loads(events_response) - - # events looks like this - # { - # u'chart_detail': {u'2014-10-01T00:00:00.000Z': 3638}, - # u'event_counts': {u'Event::Endpoint::Compliant': 679, - # u'events': {} - # } - for e in events['items']: - e[u'datastream'] = EVENT_TYPE if(endpoint == EVENTS_V1) else ALERT_TYPE - yield e - - store_state(events['next_cursor'], 
state_file_path) - if not events['has_more']: - break - else: - params['cursor'] = events['next_cursor'] - params.pop('from_date', None) - - -def store_state(next_cursor, state_file_path): - # Store cursor - log("Next run will retrieve results using cursor %s\n" % next_cursor) - with open(state_file_path, 'wb') as f: - pickle.dump(next_cursor, f, protocol=2) + SIEM_LOGGER.info(format_cef(flatten_json(i)) + u"\n") # Flattening JSON objects in Python @@ -376,10 +121,10 @@ def store_state(next_cursor, state_file_path): def flatten_json(y): out = {} - def flatten(x, name=''): + def flatten(x, name=""): if type(x) is dict: for a in x: - flatten(x[a], name + a + '_') + flatten(x[a], name + a + "_") else: out[name[:-1]] = x @@ -388,39 +133,35 @@ def flatten(x, name=''): def log(s): + """Write the log. + Arguments: + log_message {string} -- log content + """ if not QUIET: - sys.stderr.write('%s\n' % s) - - -def jitter(): - time.sleep(randint(0, 10)) - - -def request_url(opener, request): - for i in [1, 2, 3]: # Some ops we simply retry - try: - response = opener.open(request) - except urlerror.HTTPError as e: - if e.code in (503, 504, 403, 429): - log('Error "%s" (code %s) on attempt #%s of 3, retrying' % (e, e.code, i)) - if i < 3: - continue - log('Error during request. Error code: %s, Error message: %s' % (e.code, e.read())) - raise - return response.read() + sys.stderr.write("%s\n" % s) def format_prefix(data): + """ pipe and backslash in header must be escaped. escape group with backslash + Arguments: + data {string}: data + Returns: + string -- backslash escape string + """ # pipe and backslash in header must be escaped # escape group with backslash - return PREFIX_PATTERN.sub(r'\\\1', data) + return re.compile(r"([|\\])").sub(r"\\\1", data) def format_extension(data): - # equal sign and backslash in extension value must be escaped - # escape group with backslash + """ equal sign and backslash in extension value must be escaped. escape group with backslash. 
+ Arguments: + data : data + Returns: + string/list -- backslash escape string or return same value + """ if type(data) is str: - return EXTENSION_PATTERN.sub(r'\\\1', data) + return re.compile(r"([=\\])").sub(r"\\\1", data) else: return data @@ -431,14 +172,19 @@ def map_severity(severity): else: msg = 'The "%s" severity can not be mapped, defaulting to 0' % severity log(msg) - return SEVERITY_MAP['none'] + return SEVERITY_MAP["none"] def extract_prefix_fields(data): - # extract prefix fields and remove those from data dictionary - name_field = CEF_MAPPING['name'] - device_event_class_id_field = CEF_MAPPING['device_event_class_id'] - severity_field = CEF_MAPPING['severity'] + """ extract prefix fields and remove those from data dictionary + Arguments: + data {dict}: data + Returns: + fields {dict} -- fields object + """ + name_field = CEF_MAPPING["name"] + device_event_class_id_field = CEF_MAPPING["device_event_class_id"] + severity_field = CEF_MAPPING["severity"] name = data.get(name_field, MISSING_VALUE) name = format_prefix(name) @@ -452,17 +198,23 @@ def extract_prefix_fields(data): severity = map_severity(severity) data.pop(severity_field, None) - fields = {'name': name, - 'device_event_class_id': device_event_class_id, - 'severity': severity, - 'version': CEF_CONFIG['cef.version'], - 'device_vendor': CEF_CONFIG['cef.device_vendor'], - 'device_version': CEF_CONFIG['cef.device_version'], - 'device_product': CEF_CONFIG['cef.device_product']} + fields = { + "name": name, + "device_event_class_id": device_event_class_id, + "severity": severity, + "version": CEF_CONFIG["cef.version"], + "device_vendor": CEF_CONFIG["cef.device_vendor"], + "device_version": CEF_CONFIG["cef.device_version"], + "device_product": CEF_CONFIG["cef.device_product"], + } return fields def update_cef_keys(data): + """ Replace if there is a mapped CEF key + Arguments: + data {dict}: data + """ # Replace if there is a mapped CEF key for key, value in list(data.items()): new_key = 
CEF_MAPPING.get(key, key) @@ -473,6 +225,12 @@ def update_cef_keys(data): def format_cef(data): + """ Message CEF formatted + Arguments: + data {dict}: data + Returns: + data {str}: message + """ fields = extract_prefix_fields(data) msg = CEF_FORMAT % fields @@ -480,15 +238,166 @@ def format_cef(data): for index, (key, value) in enumerate(data.items()): value = format_extension(value) if index > 0: - msg += ' %s=%s' % (key, value) + msg += " %s=%s" % (key, value) else: - msg += '%s=%s' % (key, value) + msg += "%s=%s" % (key, value) return msg def remove_null_values(data): + """ Removed null value + Arguments: + data {dict}: data + Returns: + data {dict}: update data + """ return {k: v for k, v in data.items() if v is not None} +def parse_args_options(): + """ Parsed the command line arguments + Returns: + options {dict}: options data + """ + global QUIET + if "SOPHOS_SIEM_HOME" in os.environ: + app_path = os.environ["SOPHOS_SIEM_HOME"] + else: + # Setup path + app_path = os.path.join(os.getcwd()) + + config_file = os.path.join(app_path, "config.ini") + + parser = OptionParser( + description="Download event and/or alert data and output to various formats. " + "config.ini is a configuration file that exists by default in the siem-scripts " + "folder." + "Script keeps tab of its state, it will always pick-up from where it left-off " + "based on a state file stored in state folder. Set SOPHOS_SIEM_HOME environment " + "variable to point to the folder where config.ini, mapping files, state " + "and log folders will be located. state and log folders are created when the " + "script is run for the first time. 
" + ) + parser.add_option( + "-s", + "--since", + default=False, + action="store", + help="Return results since specified Unix " + "Timestamp, max last 24 hours, defaults to " + "last 12 hours if there is no state file", + ) + parser.add_option( + "-c", + "--config", + default=config_file, + action="store", + help="Specify a configuration file, " "defaults to config.ini", + ) + parser.add_option( + "-l", + "--light", + default=False, + action="store_true", + help="Ignore noisy events - web control, " + "device control, update failure, " + "application allowed, (non)compliant", + ) + parser.add_option( + "-d", "--debug", default=False, action="store_true", help="Print debug logs" + ) + parser.add_option( + "-v", "--version", default=False, action="store_true", help="Print version" + ) + parser.add_option( + "-q", + "--quiet", + default=False, + action="store_true", + help="Suppress status messages", + ) + + options, args = parser.parse_args() + + if options.config is None: + parser.error("Need a config file specified") + + if options.version: + log(VERSION) + sys.exit(0) + if options.quiet: + QUIET = True + + return options + + +def load_config(config_path): + """ Get config file data + Arguments: + config_path {str}: config file path + Returns: + cfg {dice}: config.ini data + """ + cfg = config.Config(config_path) + cfg.format = cfg.format.lower() + cfg.endpoint = cfg.endpoint.lower() + validate_format(cfg.format) + validate_endpoint(cfg.endpoint) + return cfg + +def validate_format(format): + if format not in ("json", "keyvalue", "cef"): + raise Exception("Invalid format in config.ini, format can be json, cef or keyvalue") + +def validate_endpoint(endpoint): + endpoint_map = api_client.ENDPOINT_MAP + if endpoint not in endpoint_map: + raise Exception("Invalid endpoint in config.ini, endpoint can be event, alert or all") + +def get_alerts_or_events(endpoint, options, config, state): + """ Get alerts/events data + Arguments: + endpoint {str}: endpoint name + options 
{dict}: options + config {dict}: config file details + state {dict}: state file details + """ + api_client_obj = api_client.ApiClient(endpoint, options, config, state) + results = api_client_obj.get_alerts_or_events() + + if config.format == "json": + write_json_format(results) + elif config.format == "keyvalue": + write_keyvalue_format(results) + elif config.format == "cef": + write_cef_format(results) + else: + write_json_format(results) + +def run(options, config_data, state): + """ Call the fetch alerts/events method + Arguments: + options {dict}: options + config_data {dict}: config file details + state {dict}: state file details + """ + endpoint_map = api_client.ENDPOINT_MAP + if config_data.endpoint in endpoint_map: + tuple_endpoint = endpoint_map[config_data.endpoint] + else: + tuple_endpoint = endpoint_map[DEFAULT_ENDPOINT] + + for endpoint in tuple_endpoint: + get_alerts_or_events( + endpoint, options, config_data, state + ) + + +def main(): + options = parse_args_options() + config_data = load_config(options.config) + state_data = state.State(options, config_data.state_file_path) + run(options, config_data, state_data) + if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/state.py b/state.py new file mode 100644 index 0000000..0e718f2 --- /dev/null +++ b/state.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python + +# Copyright 2019-2021 Sophos Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. +# You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing permissions and limitations under the +# License. 
+# +import sys +import os +import json +from pathlib import Path + + +class State: + def __init__(self, options, state_file): + """Class create state file and providing state file data""" + + if state_file and Path(state_file).suffix != ".json": + raise SystemExit( + "Sophos state file is not in valid format. it's must be with a .json extension" + ) + self.options = options + if "SOPHOS_SIEM_HOME" in os.environ: + app_path = os.environ["SOPHOS_SIEM_HOME"] + else: + app_path = os.path.join(os.getcwd()) + + self.state_file = self.get_state_file(app_path, state_file) + self.create_state_dir(self.state_file) + self.state_data = self.load_state_file() + + def log(self, log_message): + """Write the log. + Arguments: + log_message {string} -- log content + """ + if not self.options.quiet: + sys.stderr.write("%s\n" % log_message) + + def create_state_dir(self, state_file): + """Create state directory + Arguments: + state_file {string}: state file path + """ + state_dir = os.path.dirname(state_file) + if not os.path.exists(state_dir): + try: + os.makedirs(state_dir) + except OSError as e: + raise SystemExit("Failed to create %s, %s" % (state_dir, str(e))) + + def get_state_file(self, app_path, state_file): + """Return state cache file path + Arguments: + app_path {string}: application path + state_file {string}: state file path + Returns: + dict -- state file path + """ + if not state_file: + return os.path.join(app_path, "state", "siem_sophos.json") + else: + return ( + state_file + if os.path.isabs(state_file) + else os.path.join(app_path, state_file) + ) + + def load_state_file(self): + """Get state file data + Returns: + dict -- Return state file data or exit if found any error + """ + try: + with open(self.state_file, "rb") as f: + return json.load(f) + except IOError: + self.log("Sophos state file not found") + except json.decoder.JSONDecodeError: + raise SystemExit("Sophos state file not in valid JSON format") + return {} + + def save_state(self, state_data_key, 
state_data_value): + """save data in state file. Data store in nested object by splitting key with `.` separator + Arguments: + state_data_key {string}: state key + state_data_value {string}: state value + """ + # Store state + key_arr = state_data_key.split(".") + sub_data = self.state_data + for item in key_arr[0:-1]: + if item not in sub_data.keys(): + sub_data[item] = {} + sub_data = sub_data[item] + sub_data[key_arr[-1]] = state_data_value + + self.write_state_file(json.dumps(self.state_data, indent=4)) + + def write_state_file(self, data): + """Write data in state file + Arguments: + data {dict}: state data object + """ + with open(self.state_file, "w") as f: + try: + f.write(data) + except Exception as e: + self.log("Error :: %s" % e) + pass diff --git a/test_regression.py b/test_regression.py index 6653421..c0af325 100755 --- a/test_regression.py +++ b/test_regression.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2019 Sophos Limited +# Copyright 2019-2021 Sophos Limited # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in # compliance with the License. @@ -15,30 +15,27 @@ """ This script is for testing regressions after changes to Sophos SIEM. - Requirements - - Python 2.7 (ActivePython recommended on Windows) - - Python 3.5 (ActivePython recommended on Windows) + - Python 3.5+ (ActivePython recommended on Windows) - pycef module to decode CEF output. (https://github.com/DavidJBianco/pycef/) There is no equivalent for keyvalue, so that test is expected to fail at least half the time. - Some typical events in Central. Caveats: - Events arriving while the script is running may cause failures. 
- """ -from ConfigParser import ConfigParser +import configparser as ConfigParser import unittest import shutil import tempfile import os import json import re -import urllib2 +from urllib.request import Request, urlopen import glob -from StringIO import StringIO +import io from zipfile import ZipFile from subprocess import Popen, PIPE import pycef @@ -46,33 +43,38 @@ def find_python(ver): """Try to find Python of the given version. Attempt to work on Windows and Linux""" - win_python2 = glob.glob("c:\\python%d*\\python.exe" % ver) # assume Active Python on Windows - if win_python2: - return win_python2[0] + win_python = glob.glob( + "c:\\python%d*\\python.exe" % ver + ) # assume Active Python on Windows + if win_python: + return win_python[0] else: return "python%d" % ver class SIEMRunner: """Manage SIEM executions in a clean temporary directory, collect results""" + def __init__(self, name): self.root = tempfile.mkdtemp(prefix=name, dir=".") self.cfg_path = os.path.join(self.root, "config.ini") - self.keyval_rex = re.compile(r"^(?P\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\d\.\d+Z)\s(?P.+)$") + self.keyval_rex = re.compile( + r"^(?P\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\d\.\d+Z)\s(?P.+)$" + ) def set_config(self, name, value): """Set a config value in the config.ini file""" - config = ConfigParser() + config = ConfigParser.ConfigParser() config.read(self.cfg_path) - config.set('login', name, value) - with open(self.cfg_path, 'w') as fp: + config.set("login", name, value) + with open(self.cfg_path, "w") as fp: config.write(fp) def get_config(self, name): """Read a config value from the config.ini file""" - config = ConfigParser() + config = ConfigParser.ConfigParser() config.read(self.cfg_path) - return config.get('login', name) + return config.get("login", name) def run_python(self, pypath, *args): """Run the script with given python interpreter (pypath) supplying *args""" @@ -90,11 +92,11 @@ def run_python_version(self, ver, *args): def reset_state(self): """Remove state files, 
start over""" - os.unlink(os.path.join(self.root, "state", "siem_lastrun_events.obj")) + os.unlink(os.path.join(self.root, "state", "siem_sophos.json")) def get_results(self): """ - Find out where the results went, read, parse and return them if possible. + Find out where the results went, read, parse and return them if possible. """ ofile = self.get_config("filename") out = [] @@ -118,9 +120,11 @@ def get_results(self): def get_release(): """Return the latest version of Sophos SIEM from github. Cache it in the file 'master.zip'""" - zip_location = "https://github.com/sophos/Sophos-Central-SIEM-Integration/archive/master.zip" + zip_location = ( + "https://github.com/sophos/Sophos-Central-SIEM-Integration/archive/master.zip" + ) if not os.path.exists("master.zip"): - fp = urllib2.urlopen(zip_location) + fp = urlopen(Request(zip_location, method="HEAD")) zip_data = fp.read(int(fp.headers["Content-Length"])) fp.close() open("master.zip", "wb").write(zip_data) @@ -129,7 +133,7 @@ def get_release(): def write_from_zip(zip_data, root): """Given a zip file as data, write the contents to the given root dir""" - zf = StringIO(zip_data) + zf = io.BytesIO(zip_data) input_zip = ZipFile(zf) for name in input_zip.namelist(): data = input_zip.read(name) @@ -162,22 +166,22 @@ class TestCompareOutput(BaseTest): """Test that old and new versions of the software are the same""" def setUp(self): - # Create the original version install - self.orig_runner = SIEMRunner("orig_py2") + self.orig_runner = SIEMRunner("orig_py3") self.zip_data = get_release() write_from_zip(self.zip_data, self.orig_runner.root) # Create the new version install - self.new_runner = SIEMRunner("new_py2") + self.new_runner = SIEMRunner("new_py3") write_from_dir(".", self.new_runner.root) # Configure downloaded (original) version with the token from the version in cwd orig_token = self.orig_runner.get_config("token_info") - self.assertTrue(orig_token.startswith("<")) # we've got the as-shipped token text. Expected. 
- + self.assertTrue( + orig_token.startswith("<") + ) # we've got the as-shipped token text. Expected. new_token = self.new_runner.get_config("token_info") - self.assertTrue(new_token.startswith("url:")) # Check the substitution + self.assertTrue(new_token.startswith("url:")) # Check the substitution self.orig_runner.set_config("token_info", new_token) @@ -199,34 +203,34 @@ def RunBoth(self, ver, *args): def testJson(self): """Test the json output is identical between versions""" self.configure_all("format", "json") - orig, new = self.RunBoth(2) + orig, new = self.RunBoth(3) self.assertEqual(orig, new) def testCEF(self): """Test the CEF output is identical between versions""" self.configure_all("format", "cef") - orig, new = self.RunBoth(2) + orig, new = self.RunBoth(3) self.assertEqual(orig, new) def XXtestKeyValue(self): """ - Field order is dependent on Python dict.items() iteration order which isn't consistent between runs. - This means keys can appear in any order, and without full parsing of keyvalue data, the comparison - can't be done. - If you need this test, comment it in (Remove XX above) and keep running it. It will pass approx 50% - of the time. + Field order is dependent on Python dict.items() iteration order which isn't consistent between runs. + This means keys can appear in any order, and without full parsing of keyvalue data, the comparison + can't be done. + If you need this test, comment it in (Remove XX above) and keep running it. + It will pass approx 50% of the time. """ self.configure_all("format", "keyvalue") - orig, new = self.RunBoth(2) + orig, new = self.RunBoth(3) self.assertEqual(orig, new) def testJsonDifferentPython(self): """ - Run the new version with Python3, old version with Python2 and compare output. - This should result in the same output. + Run the new version and old version with Python3 and compare output. + This should result in the same output. 
""" self.configure_all("format", "json") - self.orig_runner.run_python_version(2) + self.orig_runner.run_python_version(3) orig = self.orig_runner.get_results() self.new_runner.run_python_version(3) @@ -235,9 +239,8 @@ def testJsonDifferentPython(self): class TestNewFunctionality(BaseTest): - def setUp(self): - self.runner = SIEMRunner("new_py2") + self.runner = SIEMRunner("new_py3") write_from_dir(".", self.runner.root) self.runner.set_config("format", "json") self.runner.set_config("filename", "result.txt") @@ -249,9 +252,9 @@ def tearDown(self): def testRunTwice(self): """Run the program twice, make sure the results file doesn't change in size""" - self.runner.run_python_version(2) + self.runner.run_python_version(3) first_run = self.runner.get_results() - self.runner.run_python_version(2) + self.runner.run_python_version(3) second_run = self.runner.get_results() self.assertEqual(first_run, second_run) @@ -261,13 +264,12 @@ def testRunWithStaleResults(self): logdir = os.path.join(self.runner.root, "log") os.makedirs(logdir) ofile = os.path.join(logdir, self.runner.get_config("filename")) - - marker = '["SOME OLD JSON LOG DATA"]\r\n' + marker = '["SOME OLD JSON LOG DATA"]\r\n'.encode() with open(ofile, "wb") as fp: fp.write(marker) before_size = os.stat(ofile).st_size self.assertEqual(before_size, len(marker)) - self.runner.run_python_version(2) + self.runner.run_python_version(3) after_size = os.stat(ofile).st_size new_marker = open(ofile, "rb").read(len(marker)) self.assertEqual(marker, new_marker) @@ -275,21 +277,20 @@ def testRunWithStaleResults(self): def testLightMode(self): noisy = ["Event::Endpoint::UpdateSuccess"] - self.runner.run_python_version(2, "--light") + self.runner.run_python_version(3, "--light") for i in self.runner.get_results(): # We know for sure this event will always be noisy. Could check for the others. 
self.assertTrue(i["type"] not in noisy) - - # Make sure the + # Make sure the self.runner.reset_state() - self.runner.run_python_version(2) + self.runner.run_python_version(3) found = False for i in self.runner.get_results(): - # We know for sure this event will always be noisy. Could check for the others. - if i["type"] in noisy: + # We know for sure this event will always be noisy. Could check for the others. + if i["type"] in noisy: found = True - self.assertTrue(found) # we expect this event to appear all over the place. + self.assertTrue(found) # we expect this event to appear all over the place. -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_api_client.py b/tests/unit/test_api_client.py new file mode 100644 index 0000000..29888de --- /dev/null +++ b/tests/unit/test_api_client.py @@ -0,0 +1,442 @@ +# Copyright 2019-2021 Sophos Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. +# You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing permissions and limitations under the +# License. +# + +""" + Unit tests for Sophos SIEM Client. 
+ + Requirements + - Python 3.5+ (ActivePython recommended on Windows) +""" + +import os +import shutil +import api_client +import sys +import unittest +import json +import time + +from mock import MagicMock +from mock import patch + + +class Options: + def __init__(self): + self.quiet = False + self.debug = True + self.light = True + self.since = False + + +class State: + def __init__(self): + self.state_data = {} + + +class Config: + def __init__(self): + self.filename = "syslog" + self.facility = "daemon" + self.address = "localhost:00" + self.socktype = "udp" + self.format = "json" + self.client_id = "" + self.client_secret = "" + self.tenant_id = "" + self.token_info = "" + self.auth_url = "" + self.api_host = "" + + +class TestApiClient(unittest.TestCase): + def setUp(self): + self.LOGGER_MOCK = MagicMock() + api_client.SIEM_LOGGER = self.LOGGER_MOCK + options = Options() + state = State() + config = Config() + api_client.urlrequest.HTTPSHandler = MagicMock() + api_client.urlrequest.build_opener = MagicMock() + os.environ["SOPHOS_SIEM_HOME"] = "fake_sophos_siem_home" + self.api_client = api_client.ApiClient( + "/siem/v1/events", options, config, state + ) + + def tearDown(self): + if os.path.exists("fake_sophos_siem_home/log"): + shutil.rmtree("fake_sophos_siem_home") + + @patch("sys.stderr.write") + def test_log(self, mock_sys_write): + self.api_client.log("test") + mock_sys_write.assert_called_once() + mock_sys_write.assert_called_with("test\n") + + def test_get_syslog_facilities(self): + result = self.api_client.get_syslog_facilities() + self.assertIn("auth", result) + + @patch("time.sleep") + def test_jitter(self, mock_time): + result = self.api_client.jitter() + mock_time.assert_called_once() + + @patch("api_client.logging.StreamHandler") + def test_add_siem_logeer_handler_stdout(self, mock_handler): + self.api_client.config.filename = "stdout" + self.api_client.add_siem_logeer_handler("/fake_tmp/") + mock_handler.assert_called_once() + + 
@patch("api_client.logging.FileHandler") + def test_add_siem_logeer_handler_other(self, mock_handler): + self.api_client.config.filename = "other" + self.api_client.add_siem_logeer_handler("/fake_tmp/") + mock_handler.assert_called_once() + + @patch("api_client.calendar.timegm") + def test_get_past_datetime(self, mock_calender): + self.api_client.get_past_datetime(12) + mock_calender.assert_called_once() + + @patch("api_client.config.Token") + @patch("api_client.urlrequest.Request") + def test_get_alerts_or_events_with_token(self, mock_urlrequest, mock_token): + mock_event_response = { + "has_more": False, + "next_cursor": "VjJfQ1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [], + } + self.api_client.make_token_request = MagicMock() + self.api_client.make_token_request.return_value = mock_event_response + response = self.api_client.get_alerts_or_events() + self.assertEqual(response["next_cursor"], mock_event_response["next_cursor"]) + self.assertEqual(len(response["items"]), 0) + self.api_client.options.since = 10 + response = self.api_client.get_alerts_or_events() + self.assertEqual(response["next_cursor"], mock_event_response["next_cursor"]) + self.assertEqual(len(response["items"]), 0) + + @patch("sys.stderr.write") + @patch("api_client.urlrequest.Request") + def test_get_alerts_or_events_with_credentials(self, mock_urlrequest, sys_write): + mock_event_response = { + "has_more": False, + "next_cursor": "TESJfQ1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [], + } + mock_tenant_response = { + "has_more": False, + "next_cursor": "TEST1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [{"id": 1, "idType": "test", "apiHost": "http://localhost"}], + } + self.api_client.config.client_id = "test_client_id" + self.api_client.config.client_secret = "test_client_secret" + self.api_client.get_tenants_from_sophos = MagicMock() + self.api_client.get_tenants_from_sophos.return_value = mock_tenant_response + 
self.api_client.make_credentials_request = MagicMock() + self.api_client.make_credentials_request.return_value = mock_event_response + response = self.api_client.get_alerts_or_events() + self.assertEqual(response["next_cursor"], mock_event_response["next_cursor"]) + self.assertEqual(len(response["items"]), 0) + self.api_client.get_tenants_from_sophos.return_value = {"error": "error"} + + with self.assertRaises(Exception) as context: + self.api_client.get_alerts_or_events() + sys_write.assert_called_with("Error :: error\n") + + @patch("api_client.urlrequest.Request") + def test_call_endpoint(self, mock_urlrequest): + mock_response = { + "has_more": False, + "next_cursor": "VjJfQ1VSU09SfDIwMTktMDQtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [ + { + "when": "2019-04-01T15:11:09.759Z", + "id": "cbaff14f-a36b-46bd-8e83-6017ad79cdef", + "customer_id": "816f36ee-dd2e-4ccd-bb12-cea766c28ade", + "severity": "low", + "created_at": "2019-04-01T15:11:09.984Z", + "source_info": {"ip": "10.1.39.32"}, + "endpoint_type": "server", + "endpoint_id": "c80c2a87-42f2-49b2-bab7-5031b69cd83e", + "origin": None, + "type": "Event::Endpoint::Registered", + "location": "mock_Mercury_1", + "source": "n/a", + "group": "PROTECTION", + "name": "New server registered: mock_Mercury_1", + }, + { + "when": "2019-04-01T15:11:41.000Z", + "id": "5bc48f19-3905-4f72-9f79-cd381c8e92ce", + "customer_id": "816f36ee-dd2e-4ccd-bb12-cea766c28ade", + "severity": "medium", + "created_at": "2019-04-01T15:11:41.053Z", + "source_info": {"ip": "10.1.39.32"}, + "endpoint_type": "server", + "endpoint_id": "c80c2a87-42f2-49b2-bab7-5031b69cd83e", + "origin": None, + "type": "Event::Endpoint::Threat::Detected", + "location": "mock_Mercury_1", + "source": "n/a", + "group": "MALWARE", + "name": "Malware detected: 'Eicar-AV-Test' at 'C:\\Program Files (x86)\\Trojan Horse\\bin\\eicar.com'", + }, + ], + } + self.api_client.request_url = MagicMock() + self.api_client.request_url.return_value = json.dumps(mock_response) + response = 
self.api_client.call_endpoint("http://localhost", None, "") + + self.assertEqual(response["next_cursor"], mock_response["next_cursor"]) + self.assertEqual(len(response["items"]), 2) + + @patch("api_client.urlrequest.Request") + def test_request_url(self, mock_urlrequest): + mock_event_response = { + "has_more": False, + "next_cursor": "VjJfQ1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [], + } + self.api_client.opener.open = MagicMock() + self.api_client.opener.open.return_value.read.return_value = mock_event_response + response = self.api_client.request_url("http://localhost", None, "") + self.assertEqual(response["has_more"], mock_event_response["has_more"]) + self.assertEqual(response["next_cursor"], mock_event_response["next_cursor"]) + self.assertEqual(len(response["items"]), 0) + + def test_get_alerts_or_events_req_args(self): + self.api_client.options.light = True + params = {"limit": 1000, "cursor": False} + response = self.api_client.get_alerts_or_events_req_args(params) + self.assertIn( + "limit=1000&cursor=False&exclude_types=Event::Endpoint::NonCompliant", + response, + ) + self.api_client.options.light = False + response = self.api_client.get_alerts_or_events_req_args(params) + self.assertEqual(response, "limit=1000&cursor=False") + + def test_make_token_request(self): + mock_response = { + "has_more": False, + "next_cursor": "TESJfQ1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [ + { + "severity": "high", + "threat": "TEST", + "endpoint_id": "123-131-3131-31313", + "endpoint_type": "test_type", + "source_info": {"ip": "0.0.0.0"}, + "customer_id": "dadaf-test1213-sfsf-test", + "name": "test", + "id": "test-1213-1213-1213-1213", + "type": "test::testpoint::testfailed", + "group": "test", + "datastream": "test", + "end": "1999-03-24T12:45:33.273Z", + "duid": "test", + "rt": "1999-03-25T12:45:35.521Z", + "dhost": "test", + "suser": "test", + } + ], + } + mock_empty_items_response = { + "has_more": False, + "next_cursor": 
"TESJfQ1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [ + ], + } + self.api_client.state.save_state = MagicMock() + self.api_client.call_endpoint = MagicMock() + self.api_client.call_endpoint.return_value = mock_response + response = self.api_client.make_token_request(False, "events", MagicMock()) + self.assertEqual(list(response), mock_response["items"]) + + self.api_client.call_endpoint.return_value = mock_empty_items_response + response = self.api_client.make_token_request(False, "events", MagicMock()) + self.assertEqual(list(response), mock_empty_items_response["items"]) + + def test_make_credentials_request(self): + tenant_response = { + "access_token": "test_access_token", + "has_more": False, + "next_cursor": "TEST1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [{"id": "1", "apiHost": "http://localhost"}], + } + mock_response = { + "has_more": False, + "next_cursor": "TESJfQ1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [ + { + "severity": "high", + "threat": "TEST", + "endpoint_id": "123-131-3131-31313", + "endpoint_type": "test_type", + "source_info": {"ip": "0.0.0.0"}, + "customer_id": "dadaf-test1213-sfsf-test", + "name": "test", + "id": "test-1213-1213-1213-1213", + "type": "test::testpoint::testfailed", + "group": "test", + "datastream": "test", + "end": "1999-03-24T12:45:33.273Z", + "duid": "test", + "rt": "1999-03-25T12:45:35.521Z", + "dhost": "test", + "suser": "test", + } + ], + } + mock_empty_items_response = { + "has_more": False, + "next_cursor": "TESJfQ1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [ + ], + } + self.api_client.state.save_state = MagicMock() + self.api_client.call_endpoint = MagicMock() + self.api_client.call_endpoint.return_value = mock_response + response = self.api_client.make_credentials_request( + False, "events", tenant_response + ) + self.assertEqual(list(response), mock_response["items"]) + self.api_client.state_data = {"tenants": {"1": {"events": time.time() - 120}}} + 
response = self.api_client.make_credentials_request( + False, "events", tenant_response + ) + self.assertEqual(list(response), mock_response["items"]) + + + self.api_client.call_endpoint.return_value = mock_empty_items_response + response = self.api_client.make_credentials_request( + False, "events", tenant_response + ) + self.assertEqual(list(response), mock_empty_items_response["items"]) + + + def test_get_partner_tenants_from_sophos(self): + whoami_response = {"id": "1", "idType": "partner"} + token_response = { + "access_token": "Test_Toekn", + } + partner_response = {"id": "1"} + self.api_client.get_sophos_jwt = MagicMock() + self.api_client.get_sophos_jwt.return_value = token_response + self.api_client.get_whoami_data = MagicMock() + self.api_client.get_whoami_data.return_value = whoami_response + self.api_client.get_partner_organization_tenants = MagicMock() + self.api_client.get_partner_organization_tenants.return_value = partner_response + response = self.api_client.get_tenants_from_sophos() + self.assertEqual(response, partner_response) + + def test_get_tenants_from_sophos_jwt_error(self): + token_error_response = {"error": "error"} + self.api_client.get_sophos_jwt = MagicMock() + self.api_client.get_sophos_jwt.return_value = token_error_response + response = self.api_client.get_tenants_from_sophos() + self.assertEqual(response, token_error_response) + + @patch("sys.stderr.write") + def test_get_tenants_from_sophos_jwt(self, mock_sys_write): + whoami_response = {"id": "1", "idType": "tenant"} + token_response = { + "access_token": "Test_Toekn", + } + self.api_client.config.tenant_id = "1" + self.api_client.get_sophos_jwt = MagicMock() + self.api_client.get_sophos_jwt.return_value = token_response + self.api_client.get_whoami_data = MagicMock() + self.api_client.get_whoami_data.return_value = whoami_response + response = self.api_client.get_tenants_from_sophos() + self.assertEqual( + response, + {"access_token": "Test_Toekn", "items": [{"id": "1", 
"idType": "tenant"}]}, + ) + self.api_client.config.tenant_id = "11" + with self.assertRaises(Exception) as context: + self.api_client.get_tenants_from_sophos() + self.assertTrue( + "Configuration file mention tenant id not matched with whoami data tenant id" + in str(context.exception) + ) + + def test_get_tenants_from_sophos_empty_whoami(self): + token_response = { + "access_token": "Test_Toekn", + } + whoami_response = {} + self.api_client.get_sophos_jwt = MagicMock() + self.api_client.get_sophos_jwt.return_value = token_response + self.api_client.get_whoami_data = MagicMock() + self.api_client.get_whoami_data.return_value = whoami_response + response = self.api_client.get_tenants_from_sophos() + self.assertEqual(response, {}) + + def test_get_sophos_jwt(self): + token_response = {"access_token": "Test_Toekn", "expires_in": time.time()} + self.api_client.config.client_id = "test_client_id" + self.api_client.config.client_secret = "test_client_secret" + self.api_client.state.save_state = MagicMock() + self.api_client.request_url = MagicMock() + self.api_client.request_url.return_value = json.dumps(token_response) + response = self.api_client.get_sophos_jwt() + self.assertEqual(response, token_response) + + def test_get_cache_sophos_jwt(self): + token_response = {"access_token": "Test_Toekn", "expires_in": time.time()} + self.api_client.config.client_id = "test_client_id" + self.api_client.config.client_secret = "test_client_secret" + self.api_client.state_data = { + "account": { + "test_client_id": { + "jwtExpiresAt": time.time() + 120, + "jwt": "old_token", + } + } + } + response = self.api_client.get_sophos_jwt() + self.assertEqual(response, {"access_token": "old_token"}) + + def test_get_whoami_data(self): + whoami_response = {"id": "1", "idType": "tenant", "apiHost": "http://localhost"} + self.api_client.config.api_host = "http://localhost" + self.api_client.config.client_secret = "test_client_secret" + self.api_client.state.save_state = MagicMock() + 
self.api_client.request_url = MagicMock() + self.api_client.request_url.return_value = json.dumps(whoami_response) + response = self.api_client.get_whoami_data("test_token") + self.assertEqual(response, whoami_response) + + def test_get_partner_organization_tenants(self): + whoami_response = { + "id": "1", + "idType": "partner", + "apiHosts": {"global": "http://localhost"}, + } + + tenant_response = { + "has_more": False, + "next_cursor": "TEST1VSU09SfDITESTETSTETtMDFUMTg6MjU6NDEuNjA2Wg==", + "items": [{"id": "1", "idType": "test", "apiHost": "http://localhost"}], + } + self.api_client.config.api_host = "http://localhost" + self.api_client.config.tenant_id = "1" + self.api_client.config.client_secret = "test_client_secret" + self.api_client.state.save_state = MagicMock() + self.api_client.request_url = MagicMock() + self.api_client.request_url.return_value = json.dumps(tenant_response) + response = self.api_client.get_partner_organization_tenants( + whoami_response, "test_token" + ) + self.assertEqual(response, tenant_response) diff --git a/tests/unit/test_call_endpoint.py b/tests/unit/test_call_endpoint.py deleted file mode 100644 index 7151321..0000000 --- a/tests/unit/test_call_endpoint.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright 2019 Sophos Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in -# compliance with the License. -# You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing permissions and limitations under the -# License. -# - -""" - Unit tests for Sophos SIEM. 
- Requirements - - Python 2.7 (ActivePython recommended on Windows) - - Python 3.5 (ActivePython recommended on Windows) - -""" - -import os -import unittest -import siem - -import mock -import json - - -class CallEndPointTest(unittest.TestCase): - - @mock.patch('siem.store_state') - @mock.patch('siem.config.Token') - def test_data_stream_in_event(self, - mock_cf_token, - mock_store_state - ): - # Setup - # Sample event - mock_event_response = { - "has_more": False, - "next_cursor": "VjJfQ1VSU09SfDIwMTktMDQtMDFUMTg6MjU6NDEuNjA2Wg==", - "items": [ - { - "when": "2019-04-01T15:11:09.759Z", - "id": "cbaff14f-a36b-46bd-8e83-6017ad79cdef", - "customer_id": "816f36ee-dd2e-4ccd-bb12-cea766c28ade", - "severity": "low", - "created_at": "2019-04-01T15:11:09.984Z", - "source_info": { - "ip": "10.1.39.32" - }, - "endpoint_type": "server", - "endpoint_id": "c80c2a87-42f2-49b2-bab7-5031b69cd83e", - "origin": None, - "type": "Event::Endpoint::Registered", - "location": "mock_Mercury_1", - "source": "n/a", - "group": "PROTECTION", - "name": "New server registered: mock_Mercury_1" - }, - { - "when": "2019-04-01T15:11:41.000Z", - "id": "5bc48f19-3905-4f72-9f79-cd381c8e92ce", - "customer_id": "816f36ee-dd2e-4ccd-bb12-cea766c28ade", - "severity": "medium", - "created_at": "2019-04-01T15:11:41.053Z", - "source_info": { - "ip": "10.1.39.32" - }, - "endpoint_type": "server", - "endpoint_id": "c80c2a87-42f2-49b2-bab7-5031b69cd83e", - "origin": None, - "type": "Event::Endpoint::Threat::Detected", - "location": "mock_Mercury_1", - "source": "n/a", - "group": "MALWARE", - "name": "Malware detected: 'Eicar-AV-Test' at 'C:\\Program Files (x86)\\Trojan Horse\\bin\\eicar.com'" - } - ] - } - - events = [] - # Run - try: - with mock.patch('siem.request_url') as mock_request_url: - # mocking the request that retrieves events from SOA - mock_request_url.return_value = json.dumps(mock_event_response) - # call_endpoint uses yield method of python and returns each alert to the caller for - # additional 
processing. Here its just appended to a list - for e in siem.call_endpoint(mock.Mock(), siem.EVENTS_V1, False, False, 'fake_state_file', mock_cf_token): - events.append(e) - - except Exception as ex: - print ex - - # Verify - self.assertEqual(len(events), 2) - self.assertEqual(events[0]["datastream"], siem.EVENT_TYPE) - - @mock.patch('siem.store_state') - @mock.patch('siem.config.Token') - def test_data_stream_in_alert(self, - mock_cf_token, - mock_store_state - ): - # Setup - # Sample alert - mock_alert_response = { - "has_more": False, - "next_cursor": "MHwyMDE5LTA0LTAxVDIxOjI2OjQ5LjIxOVo=", - "items": [ - { - "severity": "high", - "when": "2019-04-01T16:11:10.487Z", - "threat": None, - "event_service_event_id": "d2fbaebe-c169-405e-946c-a2afbfb65ce2", - "id": "d2fbaebe-c169-405e-946c-a2afbfb65ce2", - "info": None, - "created_at": "2019-04-01T16:11:10.557Z", - "customer_id": "816f36ee-dd2e-4ccd-bb12-cea766c28ade", - "threat_cleanable": None, - "data": { - "created_at": 1554135070527, - "endpoint_id": "c80c2a87-42f2-49b2-bab7-5031b69cd83e", - "endpoint_java_id": "c80c2a87-42f2-49b2-bab7-5031b69cd83e", - "endpoint_platform": "windows", - "endpoint_type": "server", - "event_service_id": "d2fbaebe-c169-405e-946c-a2afbfb65ce2", - "inserted_at": 1554135070527, - "source_info": { - "ip": "10.1.39.32" - }, - "user_match_id": "5ca22a0de5a7400deb1ab0bb" - }, - "type": "Event::Endpoint::NotProtected", - "location": "mock_Mercury_1", - "description": "Failed to protect server: mock_Mercury_1", - "source": "n/a" - }, - { - "severity": "high", - "when": "2019-04-01T16:14:17.147Z", - "threat": None, - "event_service_event_id": "69086f43-d619-4d03-a110-9a8cee3436f7", - "id": "69086f43-d619-4d03-a110-9a8cee3436f7", - "info": None, - "created_at": "2019-04-01T16:14:17.330Z", - "customer_id": "816f36ee-dd2e-4ccd-bb12-cea766c28ade", - "threat_cleanable": None, - "data": { - "created_at": 1554135257294, - "endpoint_id": "7a489a3a-0152-4b01-b1cc-c10d7f84f0bc", - "endpoint_java_id": 
"7a489a3a-0152-4b01-b1cc-c10d7f84f0bc", - "endpoint_platform": "windows", - "endpoint_type": "computer", - "event_service_id": "69086f43-d619-4d03-a110-9a8cee3436f7", - "inserted_at": 1554135257294, - "source_info": { - "ip": "10.1.39.32" - }, - "user_match_id": "5ca22ac8e5a7400deb1ab0bd" - }, - "type": "Event::Endpoint::NotProtected", - "location": "Lightning-oxidtdinku", - "description": "Failed to protect computer: Lightning-oxidtdinku", - "source": "Lightning-g7n7wdv611\\Lightning" - }, - { - "severity": "high", - "when": "2019-04-01T15:11:41.000Z", - "threat": "Eicar-AV-Test", - "event_service_event_id": "30acdaed-5387-4a39-b7cd-d7b9ac2a8c0f", - "id": "30acdaed-5387-4a39-b7cd-d7b9ac2a8c0f", - "info": None, - "created_at": "2019-04-01T15:11:41.253Z", - "customer_id": "816f36ee-dd2e-4ccd-bb12-cea766c28ade", - "threat_cleanable": False, - "data": { - "created_at": 1554131501192, - "endpoint_id": "c80c2a87-42f2-49b2-bab7-5031b69cd83e", - "endpoint_java_id": "c80c2a87-42f2-49b2-bab7-5031b69cd83e", - "endpoint_platform": "windows", - "endpoint_type": "server", - "event_service_id": "30acdaed-5387-4a39-b7cd-d7b9ac2a8c0f", - "inserted_at": 1554131501192, - "source_info": { - "ip": "10.1.39.32" - }, - "threat_id": "5ca22a2c352ea40df121ea9c", - "user_match_id": "5ca22a0de5a7400deb1ab0bb" - }, - "type": "Event::Endpoint::Threat::CleanupFailed", - "location": "mock_Mercury_1", - "description": "Manual cleanup required: 'Eicar-AV-Test' at 'C:\\Program Files (x86)\\Trojan Horse\\bin\\eicar.com'", - "source": "n/a" - } - ] - } - - alerts = [] - # Run - try: - with mock.patch('siem.request_url') as mock_request_url: - # mocking the request that retrieves alerts from SOA - mock_request_url.return_value = json.dumps(mock_alert_response) - # call_endpoint uses yield method of python and returns each alert to the caller for - # additional processing. 
Here its just appended to a list - for a in siem.call_endpoint(mock.Mock(), siem.ALERTS_V1, False, False, 'fake_state_file', mock_cf_token): - alerts.append(a) - - except Exception as ex: - print ex - - # Verify - self.assertEqual(len(alerts), 3) - self.assertEqual(alerts[0]["datastream"], siem.ALERT_TYPE) - diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..a0ddc7d --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python + +# Copyright 2019-2021 Sophos Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. +# You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing permissions and limitations under the +# License. 
+# + +import unittest +import shutil +import tempfile +import os +import config + + +class TestConfig(unittest.TestCase): + """Test Config file items are exposed as attributes on config object""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp(prefix="config_test", dir=".") + + def tearDown(self): + if os.path.exists(self.tmpdir): + shutil.rmtree(self.tmpdir) + + def testReadingWhenAttributeExists(self): + cfg_path = os.path.join(self.tmpdir, "config.ini") + with open(cfg_path, "wb") as fp: + fp.write("[login]\ntoken_info = MY_TOKEN\n".encode("utf-8")) + cfg = config.Config(cfg_path) + self.assertEqual(cfg.token_info, "MY_TOKEN") + + +class TestToken(unittest.TestCase): + """Test the token gets parsed""" + + def testParse(self): + txt = "url: https://anywhere.com/api, x-api-key: random, Authorization: Basic KJNKLJNjklNLKHB= " + t = config.Token(txt) + self.assertEqual(t.url, "https://anywhere.com/api") + self.assertEqual(t.api_key, "random") + self.assertEqual(t.authorization, "Basic KJNKLJNjklNLKHB=") diff --git a/tests/unit/test_name_mapping.py b/tests/unit/test_name_mapping.py new file mode 100644 index 0000000..a5c8c94 --- /dev/null +++ b/tests/unit/test_name_mapping.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +# Copyright 2019-2021 Sophos Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. +# You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing permissions and limitations under the +# License. 
+# + +import unittest +import copy +import name_mapping + + +def contains(dict_outer, dict_inner): + return all(item in dict_outer.items() for item in dict_inner.items()) + + +class TestNameExtraction(unittest.TestCase): + """Test logging output""" + + def setUp(self): + self.output = [] + + def tearDown(self): + pass + + def log(self, s): + self.output.append(s) + + def testUpdateNameDLPValid(self): + """DLP event with data that can be extracted""" + data = { + "type": "Event::Endpoint::DataLossPreventionUserAllowed", + "name": u"An \u2033allow transfer on acceptance by user\u2033 action was taken. " + u"Username: WIN10CLOUD2\\Sophos Rule names: \u2032test\u2032 User action: File open " + u"Application Name: Google Chrome Data Control action: Allow " + u"File type: Plain text (ASCII/UTF-8) File size: 36 " + u"Source path: C:\\Users\\Sophos\\Desktop\\test.txt", + } + expected = { + "type": "Event::Endpoint::DataLossPreventionUserAllowed", + "name": "allow transfer on acceptance by user", + "user": "WIN10CLOUD2\\Sophos", + "rule": "test", + "user_action": "File open", + "app_name": "Google Chrome", + "action": "Allow", + "file_type": "Plain text (ASCII/UTF-8)", + "file_size": "36", + "file_path": "C:\\Users\\Sophos\\Desktop\\test.txt", + } + name_mapping.update_fields(self.log, data) + self.assertTrue(all(item in data.items() for item in expected.items())) + self.assertEqual(len(self.output), 0) + + def testUpdateNameThreatValid(self): + """Threat event with data that can be extracted""" + data = { + "type": "Event::Endpoint::Threat::CleanedUp", + "name": u"Threat 'EICAR' in 'myfile.com' ", + } + expected = { + "type": "Event::Endpoint::Threat::CleanedUp", + "name": u"EICAR", + "filePath": "myfile.com", + "detection_identity_name": "EICAR", + } + name_mapping.update_fields(self.log, data) + self.assertTrue(contains(data, expected)) # expected data present + self.assertEqual(len(self.output), 0) # no error + + def testUpdateNameInvalid(self): + """A known type, but 
information can't be extracted (regex mismatch)""" + data = { + "type": "Event::Endpoint::DataLossPreventionUserAllowed", + "name": u"XXXX Garbage data XXXX", + } + before = copy.copy(data) + name_mapping.update_fields(self.log, data) + self.assertEqual( + len(self.output), 1 + ) # a line of error output, when the function bails. + self.assertEqual(data, before) # ... and data remains unchanged + + def testUpdateNameFromDescription(self): + """Ensure the name gets updated from the description, if present""" + data = {"type": "", "description": "XXX"} + expected = copy.copy(data) + expected["name"] = "XXX" + name_mapping.update_fields(self.log, data) + self.assertEqual(data, expected) + + def testInvalidType(self): + """Ensure that nothing gets changed when the type isn't recognised""" + data = {"type": "<>", "name": "some name"} + expected = copy.copy(data) + name_mapping.update_fields(self.log, data) + self.assertEqual(len(self.output), 0) # not considered an error + self.assertEqual(data, expected) + + def testSkippedType(self): + """Ensure that entry is skipped if it's to be ignored.""" + # First find an event type that is set to 'None' + toskip = None + for k, v in name_mapping.TYPE_HANDLERS.items(): + if not v: + toskip = k + break + data = {"type": toskip, "name": "some name"} + expected = copy.copy(data) + name_mapping.update_fields(self.log, data) + self.assertEqual(len(self.output), 0) # not considered an error + self.assertEqual(data, expected) diff --git a/tests/unit/test_siem.py b/tests/unit/test_siem.py index c8a6062..4db1246 100755 --- a/tests/unit/test_siem.py +++ b/tests/unit/test_siem.py @@ -1,4 +1,4 @@ -# Copyright 2019 Sophos Limited +# Copyright 2019-2021 Sophos Limited # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in # compliance with the License. @@ -15,8 +15,7 @@ Unit tests for Sophos SIEM. 
Requirements - - Python 2.7 (ActivePython recommended on Windows) - - Python 3.5 (ActivePython recommended on Windows) + - Python 3.5+ (ActivePython recommended on Windows) """ import os @@ -28,77 +27,7 @@ from mock import patch -class CreateSIEMLogHandlerUnitTests(unittest.TestCase): - - @patch('logging.handlers.SysLogHandler') - def test_syslog(self, mock): - # Setup - endpoint_config = { - 'filename': 'syslog', - 'address': 'fake_address', - 'facility': 'auth', - 'socktype': 'udp' - } - - # Run - handler = siem.create_output_handler(endpoint_config) - - # Verify - self.assertIsInstance(handler, type(mock.return_value)) - self.assertEqual(mock.call_count, 1) - mock.assert_called_with('fake_address', 4, 2) - - @patch('logging.handlers.SysLogHandler') - def test_syslog_with_port(self, mock): - # Setup - endpoint_config = { - 'filename': 'syslog', - 'address': 'fake_address:1234', - 'facility': 'cron', - 'socktype': 'tcp' - } - - # Run - handler = siem.create_output_handler(endpoint_config) - - # Verify - self.assertIsInstance(handler, type(mock.return_value)) - self.assertEqual(mock.call_count, 1) - mock.assert_called_with(('fake_address', 1234), 9, 1) - - @patch('logging.StreamHandler') - def test_stdout(self, mock): - # Setup - endpoint_config = { - 'filename': 'stdout', - } - - # Run - handler = siem.create_output_handler(endpoint_config) - - # Verify - self.assertIsInstance(handler, type(mock.return_value)) - self.assertEqual(mock.call_count, 1) - mock.assert_called_with(sys.stdout) - - @patch('logging.FileHandler') - def test_file(self, mock): - # Setup - endpoint_config = { - 'filename': 'fake_filename', - 'log_dir': 'fake_log_dir' - } - - # Run - handler = siem.create_output_handler(endpoint_config) - - # Verify - self.assertIsInstance(handler, type(mock.return_value)) - self.assertEqual(mock.call_count, 1) - mock.assert_called_with(os.path.join('fake_log_dir', 'fake_filename'), 'a', encoding='utf-8') - - -class FormatUnitTests(unittest.TestCase): +class 
TestSiem(unittest.TestCase): LOGGER_MOCK = None @@ -107,134 +36,167 @@ def setUp(self): siem.SIEM_LOGGER = self.LOGGER_MOCK @patch("name_mapping.update_fields") - def test_write_json_format(self, mock): + def test_write_json_format(self, mock_update_fields): # Setup - results = [{ - 'key': 'value' - }] + results = [{"key": "value"}] # Run siem.write_json_format(results) # Verify - self.assertEqual(mock.call_count, 1) - mock.assert_called_with(siem.log, results[0]) + self.assertEqual(mock_update_fields.call_count, 1) + mock_update_fields.assert_called_with(siem.log, results[0]) self.assertEqual(self.LOGGER_MOCK.info.call_count, 1) - self.LOGGER_MOCK.info.assert_called_with(u'{"key": "value"}\n') + self.LOGGER_MOCK.info.assert_called_with(u'{\n "key": "value"\n}\n') @patch("name_mapping.update_fields") - def test_write_keyvalue_format(self, mock): + def test_write_keyvalue_format(self, mock_update_fields): # Setup - results = [{ - 'rt': 'date' - }] + results = [{"rt": "date"}] # Run siem.write_keyvalue_format(results) # Verify - self.assertEqual(mock.call_count, 1) - mock.assert_called_with(siem.log, results[0]) + self.assertEqual(mock_update_fields.call_count, 1) + mock_update_fields.assert_called_with(siem.log, results[0]) self.assertEqual(self.LOGGER_MOCK.info.call_count, 1) self.LOGGER_MOCK.info.assert_called_with(u'date rt="date";\n') @patch("name_mapping.update_fields") - def test_write_cef_format(self, mock): + def test_write_cef_format(self, mock_update_fields): # Setup - results = [{ - 'key': 'value' - }] + results = [{"key": "value"}] # Run siem.write_cef_format(results) # Verify - self.assertEqual(mock.call_count, 1) - mock.assert_called_with(siem.log, results[0]) + self.assertEqual(mock_update_fields.call_count, 1) + mock_update_fields.assert_called_with(siem.log, results[0]) self.assertEqual(self.LOGGER_MOCK.info.call_count, 1) - self.LOGGER_MOCK.info.assert_called_with(u'CEF:0|sophos|sophos central|1.0|NA|NA|0|key=value\n') - - -class 
MainUnitTests(unittest.TestCase): - - LOGGER_MOCK = None - - def setUp(self): - self.LOGGER_MOCK = MagicMock() - siem.SIEM_LOGGER = self.LOGGER_MOCK - - @patch("siem.create_log_and_state_dir") - @patch("siem.process_endpoint") - @patch("siem.create_output_handler") - @patch("logging.FileHandler") - @patch("config.Token") - @patch("config.Config") - def test_siem_logger(self, - mock_config, - mock_token, - mock_filehandler, - mock_create_output_handler, - mock_process_endpoint, - mock_create_log_and_state_dir): - # Setup - os.environ['SOPHOS_SIEM_HOME'] = 'fake_sophos_siem_home' - fake_endpoint_config = { - 'format': 'fake_format', - 'filename': 'fake_filename', - 'state_dir': os.path.join(os.environ['SOPHOS_SIEM_HOME'], 'state'), - 'log_dir': os.path.join(os.environ['SOPHOS_SIEM_HOME'], 'log'), - 'since': False + self.LOGGER_MOCK.info.assert_called_with( + u"CEF:0|sophos|sophos central|1.0|NA|NA|0|key=value\n" + ) + + def test_flatten_json(self): + result = siem.flatten_json(1) + self.assertEqual(result, {"": 1}) + + @patch("sys.stderr.write") + def test_log(self, mock_sys_write): + QUIET = False + siem.log("test") + mock_sys_write.assert_called_once() + mock_sys_write.assert_called_with("test\n") + + def test_format_prefix(self): + result = siem.format_prefix("test\\1") + self.assertEqual(result, "test\\\\1") + + def test_format_extension(self): + result = siem.format_extension('"test"') + self.assertEqual(result, '"test"') + result = siem.format_extension({"test": '"test"'}) + self.assertEqual(result, {"test": '"test"'}) + + def test_map_severity(self): + result = siem.map_severity("low") + self.assertEqual(result, 1) + result = siem.map_severity("low_test") + self.assertEqual(result, 0) + + def test_update_cef_keys(self): + same_key_value_data = {"name": "test_name"} + different_key_value_data = {"device_event_class_id": "test_type"} + siem.update_cef_keys(same_key_value_data) + self.assertEqual(same_key_value_data, {"name": "test_name"}) + 
siem.update_cef_keys(different_key_value_data) + self.assertEqual(different_key_value_data, {"type": "test_type"}) + + def test_format_cef(self): + data = { + "device_event_class_id": "Event::TestEndpoint::TestSuccess", + "severity": "high", + "source": "suser", + "when": "end", } - mock_config.return_value = MagicMock(format=fake_endpoint_config['format'], - filename=fake_endpoint_config['filename']) + result = siem.format_cef(data) + self.assertEqual( + result, + "CEF:0|sophos|sophos central|1.0|NA|NA|8|type=Event::TestEndpoint::TestSuccess suser=suser end=end", + ) + + def test_parse_args_options(self): + options = siem.parse_args_options() + self.assertEqual(options.since, False) + self.assertEqual(options.quiet, False) + self.assertEqual(options.version, False) - # Run - siem.main() - - # Verify - self.assertEqual(self.LOGGER_MOCK.addHandler.call_count, 1) - args, kwargs = mock_create_output_handler.call_args - self.assertEqual(len(args), 1) - self.assertEqual(len(kwargs), 0) - self.assertEqual(args[0], fake_endpoint_config) - - @patch("siem.create_log_and_state_dir") - @patch("siem.process_endpoint") - @patch("siem.create_output_handler") - @patch("logging.FileHandler") - @patch("config.Token") @patch("config.Config") - def test_siem_logger_with_syslog(self, - mock_config, - mock_token, - mock_filehandler, - mock_create_output_handler, - mock_process_endpoint, - mock_create_log_and_state_dir): - # Setup - os.environ['SOPHOS_SIEM_HOME'] = 'fake_sophos_siem_home' + def test_load_config(self, mock_config): + os.environ["SOPHOS_SIEM_HOME"] = "fake_sophos_siem_home" fake_endpoint_config = { - 'format': 'fake_format', - 'filename': 'syslog', - 'state_dir': os.path.join(os.environ['SOPHOS_SIEM_HOME'], 'state'), - 'log_dir': os.path.join(os.environ['SOPHOS_SIEM_HOME'], 'log'), - 'since': False, - 'facility': 'fake_facility', - 'address': 'fake_address', - 'socktype': 'fake_socktype' + "format": "json", + "endpoint": "event", + "filename": "syslog", + 
"state_file_path": os.path.join( + os.environ["SOPHOS_SIEM_HOME"], "state", "test.json" + ), + "log_dir": os.path.join(os.environ["SOPHOS_SIEM_HOME"], "log"), + "since": False, + "facility": "fake_facility", + "address": "fake_address", + "socktype": "fake_socktype", } - mock_config.return_value = MagicMock(format=fake_endpoint_config['format'], - filename=fake_endpoint_config['filename'], - facility=fake_endpoint_config['facility'], - address=fake_endpoint_config['address'], - socktype=fake_endpoint_config['socktype']) - - # Run - siem.main() - - # Verify - self.assertEqual(self.LOGGER_MOCK.addHandler.call_count, 1) - args, kwargs = mock_create_output_handler.call_args - self.assertEqual(len(args), 1) - self.assertEqual(len(kwargs), 0) - self.assertEqual(args[0], fake_endpoint_config) + mock_config.return_value = MagicMock( + format=fake_endpoint_config["format"], + endpoint=fake_endpoint_config["endpoint"], + filename=fake_endpoint_config["filename"], + state_file_path=os.path.join( + os.environ["SOPHOS_SIEM_HOME"], "state", "test.json" + ), + facility=fake_endpoint_config["facility"], + address=fake_endpoint_config["address"], + socktype=fake_endpoint_config["socktype"], + ) + + config = siem.load_config("test/config.ini") + self.assertEqual(config.format, fake_endpoint_config["format"]) + self.assertEqual( + config.state_file_path, fake_endpoint_config["state_file_path"] + ) + + @patch("siem.write_cef_format") + @patch("siem.write_keyvalue_format") + @patch("siem.write_json_format") + @patch("api_client.ApiClient") + @patch("config.Config") + def test_run( + self, + mock_config, + mock_api_client, + mock_json_format, + mock_keyvalue_format, + mock_cef_format, + ): + mock_api_client.return_value.get_alerts_or_events.return_value = MagicMock([]) + mock_config.return_value = MagicMock(endpoint="event", format="json") + siem.run({}, mock_config, {}) + siem.write_json_format.assert_called_once() + + def test_validate_format(self): + with self.assertRaises(Exception) 
as context: + siem.validate_format("test") + self.assertTrue( + "Invalid format in config.ini, format can be json, cef or keyvalue" + in str(context.exception) + ) + + def test_validate_endpoint(self): + with self.assertRaises(Exception) as context: + siem.validate_endpoint("test") + self.assertTrue( + "Invalid endpoint in config.ini, endpoint can be event, alert or all" + in str(context.exception) + ) diff --git a/tests/unit/test_state.py b/tests/unit/test_state.py new file mode 100644 index 0000000..a5fc3c1 --- /dev/null +++ b/tests/unit/test_state.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python + +# Copyright 2019-2021 Sophos Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. +# You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing permissions and limitations under the +# License. 
#

import unittest
import shutil
import os
import mock
import state
import sys
from pathlib import Path


class Options:
    """Minimal stand-in for the options object handed to ``state.State``.

    Only the ``quiet`` attribute is provided, and it is left ``False``.
    """

    def __init__(self):
        self.quiet = False


class TestState(unittest.TestCase):
    """Test State file items are exposed as attributes on state object"""

    # Fixture locations, kept in one place so that setUp, tearDown and the
    # assertions below cannot drift apart.
    STATE_DIR = "/tmp/state"
    STATE_FILE = "/tmp/state/test_siem_sophos.json"

    def setUp(self):
        """Build a fresh ``State`` bound to the fixture state file."""
        options = Options()
        self.state = state.State(options, self.STATE_FILE)

    def tearDown(self):
        """Remove the fixture state directory, whether or not a file was written."""
        # Do not derive the directory from self.state.state_file: a test may
        # repoint that attribute (test_load_state_file_io_exception sets it to
        # "/tmp/test.json"), which would make rmtree target "/tmp" itself.
        # ignore_errors covers the case where the directory was never created.
        shutil.rmtree(self.STATE_DIR, ignore_errors=True)

    def test_init(self):
        """A new State keeps the given file path and starts with empty data."""
        path = Path("/tmp/state/")
        self.assertEqual(self.state.state_file, self.STATE_FILE)
        # NOTE(review): this inspects the *parent* of the state directory
        # (i.e. "/tmp"), not the state directory itself. Kept as in the
        # original; confirm whether ``path.is_dir()`` was intended.
        self.assertTrue(path.parent.is_dir())
        self.assertEqual(self.state.state_data, {})

    @mock.patch("state.State.get_state_file")
    @mock.patch("state.State.create_state_dir")
    @mock.patch("state.State.load_state_file")
    @mock.patch("sys.stderr.write")
    def test_log(self, mock_sys_write, mock_load_file, mock_state_dir, mock_state_file):
        """``log`` writes the message followed by a newline to stderr, once."""
        self.state.log("test")
        mock_sys_write.assert_called_once()
        mock_sys_write.assert_called_with("test\n")

    @mock.patch("state.State.create_state_dir")
    @mock.patch("state.State.load_state_file")
    def test_get_state_file_with_empty_state_path(
        self, mock_load_file, mock_state_dir
    ):
        """With no explicit state path, the default file name under ``state/`` is used."""
        filepath = self.state.get_state_file("/tmp", None)
        self.assertEqual(filepath, "/tmp/state/siem_sophos.json")

    @mock.patch("state.State.get_state_file")
    @mock.patch("state.State.create_state_dir")
    @mock.patch("sys.stderr.write")
    def test_load_state_file_io_exception(
        self, mock_sys_write, mock_state_dir, mock_state_file
    ):
        """Loading a state file that does not exist reports it on stderr."""
        # Patch decorators are applied bottom-up, so the mocks arrive in the
        # order: sys.stderr.write, create_state_dir, get_state_file.
        self.state.state_file = "/tmp/test.json"
        self.state.load_state_file()
        mock_sys_write.assert_called_with("Sophos state file not found\n")

    def test_save_state(self):
        """A dotted key saved with ``save_state`` is reloaded as nested dicts."""
        self.state.save_state("test.test_account", "test_account")
        self.state.load_state_file()
        self.assertEqual(
            self.state.state_data, {"test": {"test_account": "test_account"}}
        )