diff --git a/source/app/blueprints/case/case_ioc_routes.py b/source/app/blueprints/case/case_ioc_routes.py index d2ee4cb7f..58318ce7d 100644 --- a/source/app/blueprints/case/case_ioc_routes.py +++ b/source/app/blueprints/case/case_ioc_routes.py @@ -64,6 +64,8 @@ from app.util import ac_case_requires from app.util import response_error from app.util import response_success +from app.business.iocs import create +from app.business.errors import BusinessProcessingError case_ioc_blueprint = Blueprint( 'case_ioc', @@ -126,40 +128,13 @@ def case_ioc_state(caseid): @case_ioc_blueprint.route('/case/ioc/add', methods=['POST']) @ac_api_case_requires(CaseAccessLevel.full_access) def case_add_ioc(caseid): - try: - # validate before saving - add_ioc_schema = IocSchema() - - request_data = call_modules_hook('on_preload_ioc_create', data=request.get_json(), caseid=caseid) - - ioc = add_ioc_schema.load(request_data) - - if not check_ioc_type_id(type_id=ioc.ioc_type_id): - return response_error("Not a valid IOC type") - - ioc, existed = add_ioc(ioc=ioc, - user_id=current_user.id, - caseid=caseid - ) - link_existed = add_ioc_link(ioc.ioc_id, caseid) - - if link_existed: - return response_success("IOC already exists and linked to this case", data=add_ioc_schema.dump(ioc)) + add_ioc_schema = IocSchema() - if not link_existed: - ioc = call_modules_hook('on_postload_ioc_create', data=ioc, caseid=caseid) - - if ioc: - track_activity("added ioc \"{}\"".format(ioc.ioc_value), caseid=caseid) - - msg = "IOC already existed in DB. Updated with info on DB." if existed else "IOC added" - - return response_success(msg=msg, data=add_ioc_schema.dump(ioc)) - - return response_error("Unable to create IOC for internal reasons") - - except marshmallow.exceptions.ValidationError as e: - return response_error(msg="Data error", data=e.messages, status=400) + try: + ioc, msg = create(request.get_json(), caseid) + return response_success(msg, data=add_ioc_schema.dump(ioc)) + except BusinessProcessingError as e: + return response_error(e.get_message(), data=e.get_data()) @case_ioc_blueprint.route('/case/ioc/upload', methods=['POST']) diff --git a/source/app/blueprints/graphql/__init__.py b/source/app/blueprints/graphql/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/source/app/blueprints/graphql/cases.py b/source/app/blueprints/graphql/cases.py new file mode 100644 index 000000000..d7d4a7064 --- /dev/null +++ b/source/app/blueprints/graphql/cases.py @@ -0,0 +1,28 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from graphene_sqlalchemy import SQLAlchemyObjectType +from graphene.relay import Node + +from app.models.cases import Cases + + +class CaseObject(SQLAlchemyObjectType): + class Meta: + model = Cases + interfaces = [Node] diff --git a/source/app/blueprints/graphql/graphql_route.py b/source/app/blueprints/graphql/graphql_route.py new file mode 100644 index 000000000..98bf03366 --- /dev/null +++ b/source/app/blueprints/graphql/graphql_route.py @@ -0,0 +1,85 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from functools import wraps +from flask import request +from flask_wtf import FlaskForm +from flask import Blueprint +from flask_login import current_user + +from graphql_server.flask import GraphQLView +from graphene import ObjectType +from graphene import Schema +from graphene import List + +from app.util import is_user_authenticated +from app.util import response_error +from app.datamgmt.manage.manage_cases_db import get_filtered_cases +from app.blueprints.graphql.cases import CaseObject +from app.blueprints.graphql.iocs import AddIoc + + +class Query(ObjectType): + """This is the IRIS GraphQL queries documentation!""" + + # starting with the conversion of '/manage/cases/filter' + cases = List(CaseObject, description='Retrieves cases') + + @staticmethod + def resolve_cases(root, info): + # TODO add all parameters to filter + return get_filtered_cases(current_user.id) + + +class Mutation(ObjectType): + create_ioc = AddIoc.Field() + + +def _check_authentication_wrapper(f): + @wraps(f) + def wrap(*args, **kwargs): + if request.method == 'POST': + cookie_session = request.cookies.get('session') + if cookie_session: + form = FlaskForm() + if not form.validate(): + return response_error('Invalid CSRF token') + elif request.is_json: + request.json.pop('csrf_token') + + if not is_user_authenticated(request): + return response_error('Authentication required', status=401) + + return f(*args, **kwargs) + return wrap + + +def _create_blueprint(): + schema = Schema(query=Query, mutation=Mutation) + graphql_view = GraphQLView.as_view('graphql', schema=schema) + graphql_view_with_authentication = _check_authentication_wrapper(graphql_view) + + blueprint = Blueprint('graphql', __name__) + blueprint.add_url_rule('/graphql', view_func=graphql_view_with_authentication, methods=['POST']) + + return blueprint + + +graphql_blueprint = _create_blueprint() + +# TODO I am unsure about the code organization (directories) diff --git a/source/app/blueprints/graphql/iocs.py b/source/app/blueprints/graphql/iocs.py new file mode 100644 index 000000000..4f1aeb9cd --- /dev/null +++ b/source/app/blueprints/graphql/iocs.py @@ -0,0 +1,60 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from graphene_sqlalchemy import SQLAlchemyObjectType +from graphene import Field +from graphene import Mutation +from graphene import NonNull +from graphene import Int +from graphene import String + +from app.models.models import Ioc +from app.business.iocs import create + + +class IocObject(SQLAlchemyObjectType): + class Meta: + model = Ioc + + +class AddIoc(Mutation): + + class Arguments: + # TODO: it seems really too difficult to work with IDs. + # I don't understand why graphql_relay.from_global_id does not seem to work... + # note: I prefer NonNull rather than the syntax required=True + # TODO: Integers in graphql are only 32 bits. => will this be a problem? Should we use either float or string? + case_id = NonNull(Int) + type_id = NonNull(Int) + tlp_id = NonNull(Int) + value = NonNull(String) + # TODO add these non mandatory arguments + #description = + #tags = + + ioc = Field(IocObject) + + @staticmethod + def mutate(root, info, case_id, type_id, tlp_id, value): + request = { + 'ioc_type_id': type_id, + 'ioc_tlp_id': tlp_id, + 'ioc_value': value + } + ioc, _ = create(request, case_id) + return AddIoc(ioc=ioc) diff --git a/source/app/blueprints/manage/manage_cases_routes.py b/source/app/blueprints/manage/manage_cases_routes.py index e14dd9a1e..c94c3e398 100644 --- a/source/app/blueprints/manage/manage_cases_routes.py +++ b/source/app/blueprints/manage/manage_cases_routes.py @@ -33,7 +33,6 @@ from werkzeug import Response from werkzeug.utils import secure_filename -import app from app import db from app.datamgmt.alerts.alerts_db import get_alert_status_by_name from app.datamgmt.case.case_db import get_case, get_review_id_from_name @@ -49,17 +48,14 @@ from app.datamgmt.manage.manage_case_templates_db import get_case_templates_list, case_template_pre_modifier, \ case_template_post_modifier from app.datamgmt.manage.manage_cases_db import close_case, map_alert_resolution_to_case_status, get_filtered_cases -from app.datamgmt.manage.manage_cases_db import delete_case from app.datamgmt.manage.manage_cases_db import get_case_details_rt from app.datamgmt.manage.manage_cases_db import get_case_protagonists from app.datamgmt.manage.manage_cases_db import list_cases_dict from app.datamgmt.manage.manage_cases_db import reopen_case from app.datamgmt.manage.manage_common import get_severities_list -from app.datamgmt.manage.manage_users_db import get_user_organisations from app.forms import AddCaseForm -from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access, \ - ac_current_user_has_permission -from app.iris_engine.access_control.utils import ac_fast_check_user_has_case_access +from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.iris_engine.access_control.utils import ac_current_user_has_permission from app.iris_engine.access_control.utils import ac_set_new_case_access from app.iris_engine.module_handler.module_handler import call_modules_hook from app.iris_engine.module_handler.module_handler import configure_module_on_init @@ -67,18 +63,19 @@ from app.iris_engine.tasker.tasks import task_case_update from app.iris_engine.utils.common import build_upload_path from app.iris_engine.utils.tracker import track_activity -from app.models.alerts import AlertStatus from app.models.authorization import CaseAccessLevel from app.models.authorization import Permissions -from app.models.models import Client, ReviewStatusList +from app.models.models import ReviewStatusList from app.schema.marshables import CaseSchema, CaseDetailsSchema -from app.util import ac_api_case_requires, add_obj_history_entry +from app.util import add_obj_history_entry from app.util import ac_api_requires from app.util import ac_api_return_access_denied -from app.util import ac_case_requires from app.util import ac_requires from app.util import response_error from app.util import response_success +from app.business.cases import delete +from app.business.errors import BusinessProcessingError +from app.business.errors import PermissionDeniedError manage_cases_blueprint = Blueprint('manage_case', __name__, @@ -201,6 +198,7 @@ def manage_case_filter(caseid) -> Response: draw = 1 filtered_cases = get_filtered_cases( + current_user.id, case_ids=case_ids_str, case_customer_id=case_customer_id, case_name=case_name, @@ -216,7 +214,6 @@ def manage_case_filter(caseid) -> Response: search_value=search_value, page=page, per_page=per_page, - current_user_id=current_user.id, sort_by=order_by, sort_dir=sort_dir ) @@ -238,35 +235,14 @@ def manage_case_filter(caseid) -> Response: @manage_cases_blueprint.route('/manage/cases/delete/', methods=['POST']) @ac_api_requires(Permissions.standard_user, no_cid_required=True) def api_delete_case(cur_id, caseid): - if not ac_fast_check_current_user_has_case_access(cur_id, [CaseAccessLevel.full_access]): + try: + delete(cur_id, caseid) + return response_success('Case successfully deleted') + except BusinessProcessingError as e: + return response_error(e.get_message()) + except PermissionDeniedError: return ac_api_return_access_denied(caseid=cur_id) - if cur_id == 1: - track_activity("tried to delete case {}, but case is the primary case".format(cur_id), - caseid=caseid, ctx_less=True) - - return response_error("Cannot delete a primary case to keep consistency") - - else: - try: - call_modules_hook('on_preload_case_delete', data=cur_id, caseid=caseid) - if delete_case(case_id=cur_id): - - call_modules_hook('on_postload_case_delete', data=cur_id, caseid=caseid) - - track_activity("case {} deleted successfully".format(cur_id), ctx_less=True) - return response_success("Case successfully deleted") - - else: - track_activity("tried to delete case {}, but it doesn't exist".format(cur_id), - caseid=caseid, ctx_less=True) - - return response_error("Tried to delete a non-existing case") - - except Exception as e: - app.app.logger.exception(e) - return response_error("Cannot delete the case. Please check server logs for additional informations") - @manage_cases_blueprint.route('/manage/cases/reopen/', methods=['POST']) @ac_api_requires(Permissions.standard_user, no_cid_required=True) diff --git a/source/app/business/__init__.py b/source/app/business/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/source/app/business/cases.py b/source/app/business/cases.py new file mode 100644 index 000000000..9e5f0a230 --- /dev/null +++ b/source/app/business/cases.py @@ -0,0 +1,50 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from app import app +from app.models.authorization import CaseAccessLevel +from app.models.authorization import Permissions +from app.iris_engine.module_handler.module_handler import call_modules_hook +from app.iris_engine.utils.tracker import track_activity +from app.datamgmt.manage.manage_cases_db import delete_case +from app.business.errors import BusinessProcessingError +from app.business.permissions import check_current_user_has_some_case_access +from app.business.permissions import check_current_user_has_some_permission + + +def delete(case_identifier, context_case_identifier): + check_current_user_has_some_permission([Permissions.standard_user]) + check_current_user_has_some_case_access(case_identifier, [CaseAccessLevel.full_access]) + + if case_identifier == 1: + track_activity(f'tried to delete case {case_identifier}, but case is the primary case', + caseid=context_case_identifier, ctx_less=True) + + raise BusinessProcessingError('Cannot delete a primary case to keep consistency') + + try: + call_modules_hook('on_preload_case_delete', data=case_identifier, caseid=context_case_identifier) + if not delete_case(case_identifier): + track_activity(f'tried to delete case {case_identifier}, but it doesn\'t exist', + caseid=context_case_identifier, ctx_less=True) + raise BusinessProcessingError('Tried to delete a non-existing case') + call_modules_hook('on_postload_case_delete', data=case_identifier, caseid=context_case_identifier) + track_activity(f'case {case_identifier} deleted successfully', ctx_less=True) + except Exception as e: + app.logger.exception(e) + raise BusinessProcessingError('Cannot delete the case. Please check server logs for additional informations') diff --git a/source/app/business/errors.py b/source/app/business/errors.py new file mode 100644 index 000000000..15efb8158 --- /dev/null +++ b/source/app/business/errors.py @@ -0,0 +1,34 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + + +class BusinessProcessingError(Exception): + + def __init__(self, message, data=None): + self._message = message + self._data = data + + def get_message(self): + return self._message + + def get_data(self): + return self._data + + +class PermissionDeniedError(Exception): + pass diff --git a/source/app/business/iocs.py b/source/app/business/iocs.py new file mode 100644 index 000000000..98d763482 --- /dev/null +++ b/source/app/business/iocs.py @@ -0,0 +1,70 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from flask_login import current_user +from marshmallow.exceptions import ValidationError + +from app.models.authorization import CaseAccessLevel +from app.datamgmt.case.case_iocs_db import add_ioc +from app.datamgmt.case.case_iocs_db import add_ioc_link +from app.datamgmt.case.case_iocs_db import check_ioc_type_id +from app.schema.marshables import IocSchema +from app.iris_engine.module_handler.module_handler import call_modules_hook +from app.iris_engine.utils.tracker import track_activity +from app.business.errors import BusinessProcessingError +from app.business.permissions import check_current_user_has_some_case_access_stricter + + +def _load(request_data): + try: + add_ioc_schema = IocSchema() + return add_ioc_schema.load(request_data) + except ValidationError as e: + raise BusinessProcessingError('Data error', e.messages) + + +def create(request_json, case_identifier): + check_current_user_has_some_case_access_stricter([CaseAccessLevel.full_access]) + + # TODO ideally schema validation should be done before, outside the business logic in the REST API + # for that the hook should be called after schema validation + request_data = call_modules_hook('on_preload_ioc_create', data=request_json, caseid=case_identifier) + ioc = _load(request_data) + + if not check_ioc_type_id(type_id=ioc.ioc_type_id): + raise BusinessProcessingError('Not a valid IOC type') + + ioc, existed = add_ioc(ioc=ioc, user_id=current_user.id, caseid=case_identifier) + + link_existed = add_ioc_link(ioc.ioc_id, case_identifier) + + if link_existed: + # note: I am no big fan of returning tuples. + # It is a code smell some type is missing, or the code is badly designed. + return ioc, 'IOC already exists and linked to this case' + + if not link_existed: + ioc = call_modules_hook('on_postload_ioc_create', data=ioc, caseid=case_identifier) + + if ioc: + track_activity(f'added ioc "{ioc.ioc_value}"', caseid=case_identifier) + + msg = "IOC already existed in DB. Updated with info on DB." if existed else "IOC added" + return ioc, msg + + raise BusinessProcessingError('Unable to create IOC for internal reasons') diff --git a/source/app/business/permissions.py b/source/app/business/permissions.py new file mode 100644 index 000000000..651ab2f71 --- /dev/null +++ b/source/app/business/permissions.py @@ -0,0 +1,55 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from flask import session +from flask_login import current_user +from flask import request + +from app.util import get_case_access +from app.iris_engine.access_control.utils import ac_get_effective_permissions_of_user +from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.business.errors import PermissionDeniedError + + +def check_current_user_has_some_case_access(case_identifier, access_levels): + if not ac_fast_check_current_user_has_case_access(case_identifier, access_levels): + raise PermissionDeniedError() + + +# TODO: really this and the previous method should be merged. +# This one comes from ac_api_case_requires, whereas the other one comes from the way api_delete_case was written... +def check_current_user_has_some_case_access_stricter(access_levels): + redir, caseid, has_access = get_case_access(request, access_levels, from_api=True) + + # TODO: do we really want to keep the details of the errors, when permission is denied => more work, more complex code? + if not caseid or redir: + raise PermissionDeniedError() + + if not has_access: + raise PermissionDeniedError() + + +def check_current_user_has_some_permission(permissions): + if 'permissions' not in session: + session['permissions'] = ac_get_effective_permissions_of_user(current_user) + + for permission in permissions: + if session['permissions'] & permission.value: + return + + raise PermissionDeniedError() diff --git a/source/app/datamgmt/manage/manage_cases_db.py b/source/app/datamgmt/manage/manage_cases_db.py index a0716baed..87c0c1f03 100644 --- a/source/app/datamgmt/manage/manage_cases_db.py +++ b/source/app/datamgmt/manage/manage_cases_db.py @@ -381,7 +381,8 @@ def delete_case(case_id): return True -def get_filtered_cases(start_open_date: str = None, +def get_filtered_cases(current_user_id, + start_open_date: str = None, end_open_date: str = None, case_customer_id: int = None, case_ids: list = None, @@ -395,7 +396,6 @@ def get_filtered_cases(start_open_date: str = None, case_soc_id: str = None, per_page: int = None, page: int = None, - current_user_id = None, search_value=None, sort_by=None, sort_dir='asc' diff --git a/source/app/views.py b/source/app/views.py index bfaf836e9..6bf1a0225 100644 --- a/source/app/views.py +++ b/source/app/views.py @@ -30,6 +30,7 @@ from app.blueprints.case.case_routes import case_blueprint from app.blueprints.context.context import ctx_blueprint # Blueprints +from app.blueprints.graphql.graphql_route import graphql_blueprint from app.blueprints.dashboard.dashboard_routes import dashboard_blueprint from app.blueprints.datastore.datastore_routes import datastore_blueprint from app.blueprints.demo_landing.demo_landing import demo_blueprint @@ -67,6 +68,8 @@ from app.models.authorization import User from app.post_init import run_post_init + +app.register_blueprint(graphql_blueprint) app.register_blueprint(dashboard_blueprint) app.register_blueprint(overview_blueprint) app.register_blueprint(login_blueprint) diff --git a/source/requirements.txt b/source/requirements.txt index b80f631c1..772a01057 100644 --- a/source/requirements.txt +++ b/source/requirements.txt @@ -32,6 +32,10 @@ PyJWT==2.4.0 cryptography>=39.0.1 ldap3==2.9.1 pyintelowl>=4.4.0 +graphene==3.3 +# unfortunately we are relying on a beta version here. I hope a definitive version gets released soon +graphql-server[flask]==3.0.0b7 +graphene-sqlalchemy==3.0.0rc1 dependencies/docx_generator-0.8.0-py3-none-any.whl dependencies/iris_interface-1.2.0-py3-none-any.whl diff --git a/tests/docker_compose.py b/tests/docker_compose.py index dd6223f07..c30d2def1 100644 --- a/tests/docker_compose.py +++ b/tests/docker_compose.py @@ -18,6 +18,8 @@ import subprocess +_DOCKER_COMPOSE = ['docker', 'compose'] + class DockerCompose: @@ -25,10 +27,10 @@ def __init__(self, docker_compose_path): self._docker_compose_path = docker_compose_path def start(self): - subprocess.check_call(['docker', 'compose', 'up', '--detach'], cwd=self._docker_compose_path) + subprocess.check_call(_DOCKER_COMPOSE + ['up', '--detach'], cwd=self._docker_compose_path) def extract_all_logs(self): - return subprocess.check_output(['docker', 'compose', 'logs', '--no-color'], cwd=self._docker_compose_path, universal_newlines=True) + return subprocess.check_output(_DOCKER_COMPOSE + ['logs', '--no-color'], cwd=self._docker_compose_path, universal_newlines=True) def stop(self): - subprocess.check_call(['docker', 'compose', 'down', '--volumes'], cwd=self._docker_compose_path) + subprocess.check_call(_DOCKER_COMPOSE + ['down', '--volumes'], cwd=self._docker_compose_path) diff --git a/tests/graphql_api.py b/tests/graphql_api.py new file mode 100644 index 000000000..98dc5864b --- /dev/null +++ b/tests/graphql_api.py @@ -0,0 +1,29 @@ +# IRIS Source Code +# Copyright (C) 2023 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +import requests + + +class GraphQLApi: + + def __init__(self, url, api_key): + self._url = url + self._headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'} + + def execute(self, payload): + return requests.post(self._url, headers=self._headers, json=payload) diff --git a/tests/iris.py b/tests/iris.py index 2bc2f9da6..7f415bff6 100644 --- a/tests/iris.py +++ b/tests/iris.py @@ -21,9 +21,10 @@ import time from docker_compose import DockerCompose from rest_api import RestApi +from graphql_api import GraphQLApi from server_timeout_error import ServerTimeoutError -_API_URL = 'http://127.0.0.1:8000' +API_URL = 'http://127.0.0.1:8000' _API_KEY = 'B8BA5D730210B50F41C06941582D7965D57319D5685440587F98DFDC45A01594' _IRIS_PATH = Path('..') _TEST_DATA_PATH = Path('./data') @@ -33,7 +34,8 @@ class Iris: def __init__(self): self._docker_compose = DockerCompose(_IRIS_PATH) - self._api = RestApi(_API_URL, _API_KEY) + self._api = RestApi(API_URL, _API_KEY) + self._graphql_api = GraphQLApi(API_URL + '/graphql', _API_KEY) def _wait(self, condition, attempts, sleep_duration=1): count = 0 @@ -89,3 +91,7 @@ def update_case(self, case_identifier, data): def get_cases(self): return self._api.get('/manage/cases/list') + + def execute_graphql_query(self, payload): + response = self._graphql_api.execute(payload) + return response.json() diff --git a/tests/tests.py b/tests/tests.py index f88b1f7a0..66e6ed971 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -18,6 +18,9 @@ from unittest import TestCase from iris import Iris +from iris import API_URL +from graphql_api import GraphQLApi +from base64 import b64encode class Tests(TestCase): @@ -54,3 +57,55 @@ def test_update_case_should_not_require_case_name_issue_358(self): case_identifier = response['data']['case_id'] response = self._subject.update_case(case_identifier, {'case_tags': 'test,example'}) self.assertEqual('success', response['status']) + + def test_graphql_endpoint_should_reject_requests_with_wrong_authentication_token(self): + graphql_api = GraphQLApi(API_URL + '/graphql', 64*'0') + payload = { + 'query': '{ cases { name } }' + } + response = graphql_api.execute(payload) + self.assertEqual(401, response.status_code) + + def test_graphql_cases_should_contain_the_initial_case(self): + payload = { + 'query': '{ cases { name } }' + } + body = self._subject.execute_graphql_query(payload) + case_names = [] + for case in body['data']['cases']: + case_names.append(case['name']) + self.assertIn('#1 - Initial Demo', case_names) + + def test_graphql_cases_should_have_a_global_identifier(self): + payload = { + 'query': '{ cases { id name } }' + } + body = self._subject.execute_graphql_query(payload) + first_case = self._get_first_case(body) + self.assertEqual(b64encode(b'CaseObject:1').decode(), first_case['id']) + + def test_graphql_create_ioc_should_not_fail(self): + payload = { + 'query': f'''mutation {{ + createIoc(caseId: 1, typeId: 1, tlpId: 1, value: "8.8.8.8", + description: "some description", tags: "") {{ + ioc {{ iocValue }} + }} + }}''' + } + payload = { + 'query': f'''mutation {{ + createIoc(caseId: 1, typeId: 1, tlpId: 1, value: "8.8.8.8") {{ + ioc {{ iocValue }} + }} + }}''' + } + body = self._subject.execute_graphql_query(payload) + self.assertNotIn('errors', body) + + def _get_first_case(self, body): + for case in body['data']['cases']: + if case['name'] == '#1 - Initial Demo': + return case + +# TODO: should maybe try to use gql \ No newline at end of file