diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21ae27b14..496966f01 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,22 +33,174 @@ jobs: args: check --output-format=github src: ./source - tests: - name: API tests + build-docker-db: + name: Build docker db runs-on: ubuntu-22.04 steps: - name: Check out iris uses: actions/checkout@v4 - - name: Build dockers + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build and export + uses: docker/build-push-action@v6 + with: + context: docker/db + tags: iriswebapp_db:develop + outputs: type=docker,dest=${{ runner.temp }}/iriswebapp_db.tar + cache-from: type=gha + cache-to: type=gha,mode=max + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: iriswebapp_db + path: ${{ runner.temp }}/iriswebapp_db.tar + + build-docker-nginx: + name: Build docker nginx + runs-on: ubuntu-22.04 + steps: + - name: Check out iris + uses: actions/checkout@v4 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build and export + uses: docker/build-push-action@v6 + with: + context: docker/nginx + tags: iriswebapp_nginx:develop + build-args: | + NGINX_CONF_GID=1234 + NGINX_CONF_FILE=nginx.conf + outputs: type=docker,dest=${{ runner.temp }}/iriswebapp_nginx.tar + cache-from: type=gha + cache-to: type=gha,mode=max + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: iriswebapp_nginx + path: ${{ runner.temp }}/iriswebapp_nginx.tar + + build-docker-app: + name: Build docker app + runs-on: ubuntu-22.04 + steps: + - name: Check out iris + uses: actions/checkout@v4 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build and export + uses: docker/build-push-action@v6 + with: + context: . + file: docker/webApp/Dockerfile + tags: iriswebapp_app:develop + outputs: type=docker,dest=${{ runner.temp }}/iriswebapp_app.tar + cache-from: type=gha + cache-to: type=gha,mode=max + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: iriswebapp_app + path: ${{ runner.temp }}/iriswebapp_app.tar + + build-graphql-documentation: + name: Generate graphQL documentation + runs-on: ubuntu-22.04 + needs: + - build-docker-db + - build-docker-nginx + - build-docker-app + steps: + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + pattern: iriswebapp_* + path: ${{ runner.temp }} + merge-multiple: true + - name: Load docker images + run: | + docker load --input ${{ runner.temp }}/iriswebapp_db.tar + docker load --input ${{ runner.temp }}/iriswebapp_nginx.tar + docker load --input ${{ runner.temp }}/iriswebapp_app.tar + - name: Check out iris + uses: actions/checkout@v4 + - name: Start development server run: | - # TODO using the environment file from tests to build here. - # I am a bit uneasy with this choice. - # For now this works, but if we come to have different .env files for different tests, it won't anymore. - # Maybe the .env should be split to differentiate the variables used during the build from the variables used at runtime, - # or maybe the docker building phase should also be part of the tests - # and we should build different dockers according to the scenarios? This sounds like an issue to me... + # Even though, we use --env-file option when running docker compose, this is still necessary, because the compose has a env_file attribute :( + # TODO should move basic.env file, which is in directory tests, up. It's used in several places. Maybe, rename it into dev.env cp tests/data/basic.env .env - docker compose --file docker-compose.dev.yml build + docker compose --file docker-compose.dev.yml --env-file tests/data/basic.env up --detach + - name: Generate GraphQL documentation + run: | + npx spectaql@^3.0.2 source/spectaql/config.yml + - name: Stop development server + run: | + docker compose down + - uses: actions/upload-artifact@v4 + with: + name: GraphQL DFIR-IRIS documentation + path: public + if-no-files-found: error + + test-api: + name: Test API + runs-on: ubuntu-22.04 + needs: + - build-docker-db + - build-docker-nginx + - build-docker-app + steps: + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + pattern: iriswebapp_* + path: ${{ runner.temp }} + merge-multiple: true + - name: Load docker images + run: | + docker load --input ${{ runner.temp }}/iriswebapp_db.tar + docker load --input ${{ runner.temp }}/iriswebapp_nginx.tar + docker load --input ${{ runner.temp }}/iriswebapp_app.tar + - name: Check out iris + uses: actions/checkout@v4 + - name: Start development server + run: | + # Even though, we use --env-file option when running docker compose, this is still necessary, because the compose has a env_file attribute :( + # TODO should move basic.env file, which is in directory tests, up. It's used in several places. Maybe, rename it into dev.env + cp tests/data/basic.env .env + docker compose --file docker-compose.dev.yml up --detach + - name: Run tests + working-directory: tests + run: | + python -m venv venv + source venv/bin/activate + pip install -r requirements.txt + PYTHONUNBUFFERED=true python -m unittest --verbose + - name: Stop development server + run: | + docker compose down + + test-e2e: + name: End to end tests + runs-on: ubuntu-22.04 + needs: + - build-docker-db + - build-docker-nginx + - build-docker-app + steps: + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + pattern: iriswebapp_* + path: ${{ runner.temp }} + merge-multiple: true + - name: Load docker images + run: | + docker load --input ${{ runner.temp }}/iriswebapp_db.tar + docker load --input ${{ runner.temp }}/iriswebapp_nginx.tar + docker load --input ${{ runner.temp }}/iriswebapp_app.tar + - name: Check out iris + uses: actions/checkout@v4 - uses: actions/setup-node@v4 with: node-version: 20 @@ -61,30 +213,20 @@ jobs: run: | npm ci npm run build - - name: Run tests - working-directory: tests - run: | - python -m venv venv - source venv/bin/activate - pip install -r requirements.txt - docker compose --file ../docker-compose.dev.yml --env-file data/basic.env up --detach --wait - PYTHONUNBUFFERED=true python -m unittest --verbose - docker compose down - - name: Start development server - run: | - docker compose --file docker-compose.dev.yml up --detach - name: Install e2e dependencies working-directory: e2e run: npm ci - name: Install playwright dependencies working-directory: e2e run: npx playwright install chromium firefox + - name: Start development server + run: | + # TODO should move basic.env file, which is in directory tests, up. It's used in several places. Maybe, rename it into dev.env + cp tests/data/basic.env .env + docker compose --file docker-compose.dev.yml up --detach - name: Run end to end tests working-directory: e2e run: npx playwright test - - name: Generate GraphQL documentation - run: | - npx spectaql@^3.0.2 source/spectaql/config.yml - name: Stop development server run: | docker compose down @@ -93,8 +235,4 @@ jobs: with: name: playwright-report path: e2e/playwright-report/ - - uses: actions/upload-artifact@v4 - with: - name: GraphQL DFIR-IRIS documentation - path: public - if-no-files-found: error + diff --git a/source/app/blueprints/rest/case/case_tasks_routes.py b/source/app/blueprints/rest/case/case_tasks_routes.py index 2ff091f0c..48e7a6e64 100644 --- a/source/app/blueprints/rest/case/case_tasks_routes.py +++ b/source/app/blueprints/rest/case/case_tasks_routes.py @@ -54,6 +54,7 @@ @case_tasks_rest_blueprint.route('/case/tasks/list', methods=['GET']) +@endpoint_deprecated('GET', '/api/v2/cases//tasks') @ac_requires_case_identifier(CaseAccessLevel.read_only, CaseAccessLevel.full_access) @ac_api_requires() def case_get_tasks(caseid): diff --git a/source/app/blueprints/rest/endpoints.py b/source/app/blueprints/rest/endpoints.py index de35479dc..8cbb394cd 100644 --- a/source/app/blueprints/rest/endpoints.py +++ b/source/app/blueprints/rest/endpoints.py @@ -17,6 +17,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. from functools import wraps +from flask_sqlalchemy.pagination import Pagination from app import app from app.blueprints.responses import response_error, response @@ -28,6 +29,28 @@ def response_api_success(data): return response(200, data=data) +def _get_next_page(paginated_elements: Pagination): + if paginated_elements.has_next: + next_page = paginated_elements.has_next + else: + next_page = None + return next_page + + +def response_api_paginated(schema, paginated_elements: Pagination): + data = schema.dump(paginated_elements.items, many=True) + next_page = _get_next_page(paginated_elements) + + result = { + 'total': paginated_elements.total, + 'data': data, + 'last_page': paginated_elements.pages, + 'current_page': paginated_elements.page, + 'next_page': next_page + } + + return response_api_success(result) + def response_api_deleted(): return response(204) diff --git a/source/app/blueprints/rest/manage/manage_cases_routes.py b/source/app/blueprints/rest/manage/manage_cases_routes.py index a6871c77f..c9fb341e6 100644 --- a/source/app/blueprints/rest/manage/manage_cases_routes.py +++ b/source/app/blueprints/rest/manage/manage_cases_routes.py @@ -55,6 +55,7 @@ from app.blueprints.access_controls import ac_api_return_access_denied from app.blueprints.responses import response_error from app.blueprints.responses import response_success +from app.blueprints.rest.parsing import parse_pagination_parameters from app.business.cases import cases_delete from app.business.cases import cases_update from app.business.cases import cases_create @@ -81,11 +82,9 @@ def get_case_api(cur_id): @ac_api_requires() def manage_case_filter() -> Response: - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) + pagination_parameters = parse_pagination_parameters(request) + case_ids_str = request.args.get('case_ids', None, type=str) - order_by = request.args.get('order_by', type=str) - sort_dir = request.args.get('sort_dir', 'asc', type=str) if case_ids_str: try: @@ -112,6 +111,8 @@ def manage_case_filter() -> Response: draw = 1 filtered_cases = get_filtered_cases( + current_user.id, + pagination_parameters, case_ids=case_ids_str, case_customer_id=case_customer_id, case_name=case_name, @@ -124,12 +125,7 @@ def manage_case_filter() -> Response: case_soc_id=case_soc_id, start_open_date=start_open_date, end_open_date=end_open_date, - search_value=search_value, - page=page, - per_page=per_page, - current_user_id=current_user.id, - sort_by=order_by, - sort_dir=sort_dir + search_value=search_value ) if filtered_cases is None: return response_error('Filtering error') diff --git a/source/app/blueprints/rest/parsing.py b/source/app/blueprints/rest/parsing.py index d84907643..25cdd8e7f 100644 --- a/source/app/blueprints/rest/parsing.py +++ b/source/app/blueprints/rest/parsing.py @@ -16,6 +16,8 @@ # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +from app.models.pagination_parameters import PaginationParameters + def parse_comma_separated_identifiers(identifiers: str): return [int(identifier) for identifier in identifiers.split(',')] @@ -27,3 +29,13 @@ def parse_boolean(parameter: str): if parameter == 'false': return False raise ValueError(f'Expected true or false, got {parameter}') + + +def parse_pagination_parameters(request) -> PaginationParameters: + arguments = request.args + page = arguments.get('page', 1, type=int) + per_page = arguments.get('per_page', 10, type=int) + order_by = arguments.get('order_by', type=str) + sort_dir = arguments.get('sort_dir', 'asc', type=str) + + return PaginationParameters(page, per_page, order_by, sort_dir) diff --git a/source/app/blueprints/rest/v2/cases/__init__.py b/source/app/blueprints/rest/v2/cases/__init__.py index 811c35352..1a78ae3fe 100644 --- a/source/app/blueprints/rest/v2/cases/__init__.py +++ b/source/app/blueprints/rest/v2/cases/__init__.py @@ -28,6 +28,8 @@ from app.blueprints.rest.endpoints import response_api_not_found from app.blueprints.rest.endpoints import response_api_created from app.blueprints.rest.endpoints import response_api_error +from app.blueprints.rest.endpoints import response_api_paginated +from app.blueprints.rest.parsing import parse_pagination_parameters from app.blueprints.rest.v2.cases.assets import case_assets_blueprint from app.blueprints.rest.v2.cases.iocs import case_iocs_blueprint from app.blueprints.rest.v2.cases.tasks import case_tasks_blueprint @@ -76,12 +78,9 @@ def get_cases() -> Response: Handles getting cases, with optional filtering & pagination """ - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) - case_ids_str = request.args.get( - 'case_ids', None, type=parse_comma_separated_identifiers) - order_by = request.args.get('order_by', type=str) - sort_dir = request.args.get('sort_dir', 'asc', type=str) + pagination_parameters = parse_pagination_parameters(request) + + case_ids_str = request.args.get('case_ids', None, type=parse_comma_separated_identifiers) case_customer_id = request.args.get('case_customer_id', None, type=str) case_name = request.args.get('case_name', None, type=str) @@ -99,6 +98,8 @@ def get_cases() -> Response: is_open = request.args.get('is_open', None, type=parse_boolean) filtered_cases = get_filtered_cases( + current_user.id, + pagination_parameters, case_ids=case_ids_str, case_customer_id=case_customer_id, case_name=case_name, @@ -112,25 +113,13 @@ def get_cases() -> Response: start_open_date=start_open_date, end_open_date=end_open_date, search_value='', - page=page, - per_page=per_page, - current_user_id=current_user.id, - sort_by=order_by, - sort_dir=sort_dir, is_open=is_open ) if filtered_cases is None: return response_api_error('Filtering error') - cases = { - 'total': filtered_cases.total, - 'data': CaseSchemaForAPIV2().dump(filtered_cases.items, many=True), - 'last_page': filtered_cases.pages, - 'current_page': filtered_cases.page, - 'next_page': filtered_cases.next_num if filtered_cases.has_next else None, - } - - return response_api_success(data=cases) + case_schema = CaseSchemaForAPIV2() + return response_api_paginated(case_schema, filtered_cases) @cases_blueprint.get('/') diff --git a/source/app/blueprints/rest/v2/cases/iocs.py b/source/app/blueprints/rest/v2/cases/iocs.py index 45e26b2a6..2b2d0f014 100644 --- a/source/app/blueprints/rest/v2/cases/iocs.py +++ b/source/app/blueprints/rest/v2/cases/iocs.py @@ -21,11 +21,19 @@ from flask import request from app.blueprints.access_controls import ac_api_requires -from app.blueprints.rest.endpoints import response_api_created, response_api_deleted, response_api_not_found +from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.endpoints import response_api_deleted +from app.blueprints.rest.endpoints import response_api_not_found from app.blueprints.rest.endpoints import response_api_error from app.blueprints.rest.endpoints import response_api_success -from app.business.errors import BusinessProcessingError, ObjectNotFoundError -from app.business.iocs import iocs_create, iocs_get, iocs_delete, iocs_update +from app.blueprints.rest.endpoints import response_api_paginated +from app.blueprints.rest.parsing import parse_pagination_parameters +from app.business.errors import BusinessProcessingError +from app.business.errors import ObjectNotFoundError +from app.business.iocs import iocs_create +from app.business.iocs import iocs_get +from app.business.iocs import iocs_delete +from app.business.iocs import iocs_update from app.datamgmt.case.case_iocs_db import get_filtered_iocs from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access from app.models.authorization import CaseAccessLevel @@ -44,10 +52,7 @@ def get_case_iocs(case_identifier): if not ac_fast_check_current_user_has_case_access(case_identifier, [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): return ac_api_return_access_denied(caseid=case_identifier) - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 10, type=int) - order_by = request.args.get('order_by', type=str) - sort_dir = request.args.get('sort_dir', 'asc', type=str) + pagination_parameters = parse_pagination_parameters(request) ioc_type_id = request.args.get('ioc_type_id', None, type=int) ioc_type = request.args.get('ioc_type', None, type=str) @@ -57,33 +62,21 @@ def get_case_iocs(case_identifier): ioc_tags = request.args.get('ioc_tags', None, type=str) filtered_iocs = get_filtered_iocs( + pagination_parameters, caseid=case_identifier, ioc_type_id=ioc_type_id, ioc_type=ioc_type, ioc_tlp_id=ioc_tlp_id, ioc_value=ioc_value, ioc_description=ioc_description, - ioc_tags=ioc_tags, - page=page, - per_page=per_page, - sort_by=order_by, - sort_dir=sort_dir + ioc_tags=ioc_tags ) if filtered_iocs is None: return response_api_error('Filtering error') - iocs = IocSchemaForAPIV2().dump(filtered_iocs.items, many=True) - - iocs = { - 'total': filtered_iocs.total, - 'data': iocs, - 'last_page': filtered_iocs.pages, - 'current_page': filtered_iocs.page, - 'next_page': filtered_iocs.next_num if filtered_iocs.has_next else None, - } - - return response_api_success(data=iocs) + iocs_schema = IocSchemaForAPIV2() + return response_api_paginated(iocs_schema, filtered_iocs) @case_iocs_blueprint.post('') diff --git a/source/app/blueprints/rest/v2/cases/tasks.py b/source/app/blueprints/rest/v2/cases/tasks.py index 826caac07..8e823fb3f 100644 --- a/source/app/blueprints/rest/v2/cases/tasks.py +++ b/source/app/blueprints/rest/v2/cases/tasks.py @@ -23,7 +23,9 @@ from app.blueprints.rest.endpoints import response_api_not_found from app.blueprints.rest.endpoints import response_api_deleted from app.blueprints.rest.endpoints import response_api_success +from app.blueprints.rest.endpoints import response_api_paginated from app.blueprints.rest.endpoints import response_api_created +from app.blueprints.rest.parsing import parse_pagination_parameters from app.blueprints.access_controls import ac_api_return_access_denied from app.blueprints.access_controls import ac_api_requires from app.schema.marshables import CaseTaskSchema @@ -33,6 +35,7 @@ from app.business.tasks import tasks_get from app.business.tasks import tasks_update from app.business.tasks import tasks_delete +from app.business.tasks import tasks_filter from app.models.authorization import CaseAccessLevel from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access @@ -60,6 +63,21 @@ def add_case_task(case_identifier): return response_api_error(e.get_message()) +@case_tasks_blueprint.get('') +@ac_api_requires() +def case_get_tasks(case_identifier): + + if not ac_fast_check_current_user_has_case_access(case_identifier, [CaseAccessLevel.read_only, CaseAccessLevel.full_access]): + return ac_api_return_access_denied(caseid=case_identifier) + + pagination_parameters = parse_pagination_parameters(request) + + tasks = tasks_filter(case_identifier, pagination_parameters) + + task_schema = CaseTaskSchema() + return response_api_paginated(task_schema, tasks) + + @case_tasks_blueprint.get('/') @ac_api_requires() def get_case_task(case_identifier, identifier): diff --git a/source/app/business/tasks.py b/source/app/business/tasks.py index 82896c11c..44ac1e4c1 100644 --- a/source/app/business/tasks.py +++ b/source/app/business/tasks.py @@ -18,6 +18,7 @@ from datetime import datetime +from flask_sqlalchemy.pagination import Pagination from flask_login import current_user from app import db @@ -25,10 +26,12 @@ from app.datamgmt.case.case_tasks_db import add_task from app.datamgmt.case.case_tasks_db import update_task_assignees from app.datamgmt.case.case_tasks_db import get_task +from app.datamgmt.case.case_tasks_db import get_filtered_tasks from app.datamgmt.states import update_tasks_state from app.iris_engine.module_handler.module_handler import call_modules_hook from app.iris_engine.utils.tracker import track_activity from app.models.models import CaseTasks +from app.models.pagination_parameters import PaginationParameters from app.schema.marshables import CaseTaskSchema from app.business.errors import BusinessProcessingError from app.business.errors import ObjectNotFoundError @@ -84,6 +87,10 @@ def tasks_get(identifier) -> CaseTasks: return task +def tasks_filter(case_identifier, pagination_parameters: PaginationParameters) -> Pagination: + return get_filtered_tasks(case_identifier, pagination_parameters) + + def tasks_update(task: CaseTasks, request_json): case_identifier = task.task_case_id request_data = call_modules_hook('on_preload_task_update', data=request_json, caseid=case_identifier) diff --git a/source/app/datamgmt/case/case_iocs_db.py b/source/app/datamgmt/case/case_iocs_db.py index 355525a9a..708af2448 100644 --- a/source/app/datamgmt/case/case_iocs_db.py +++ b/source/app/datamgmt/case/case_iocs_db.py @@ -35,6 +35,7 @@ from app.models.authorization import User from app.models.authorization import UserCaseEffectiveAccess from app.models.authorization import CaseAccessLevel +from app.models.pagination_parameters import PaginationParameters def get_iocs(case_identifier) -> list[Ioc]: @@ -343,24 +344,22 @@ def _build_filter_ioc_query( def get_filtered_iocs( + pagination_parameters: PaginationParameters, caseid: int = None, ioc_type_id: int = None, ioc_type: str = None, ioc_tlp_id: int = None, ioc_value: str = None, ioc_description: str = None, - ioc_tags: str = None, - per_page: int = None, - page: int = None, - sort_by=None, - sort_dir='asc' + ioc_tags: str = None ): query = _build_filter_ioc_query(caseid=caseid, ioc_type_id=ioc_type_id, ioc_type=ioc_type, ioc_tlp_id=ioc_tlp_id, ioc_value=ioc_value, - ioc_description=ioc_description, ioc_tags=ioc_tags, sort_by=sort_by, sort_dir=sort_dir) + ioc_description=ioc_description, ioc_tags=ioc_tags, + sort_by=pagination_parameters.get_order_by(), sort_dir=pagination_parameters.get_direction()) try: - filtered_iocs = query.paginate(page=page, per_page=per_page, error_out=False) + filtered_iocs = query.paginate(page=pagination_parameters.get_page(), per_page=pagination_parameters.get_per_page(), error_out=False) except Exception as e: app.logger.exception(f"Error getting cases: {str(e)}") diff --git a/source/app/datamgmt/case/case_tasks_db.py b/source/app/datamgmt/case/case_tasks_db.py index d3097cc35..bef9fd4fb 100644 --- a/source/app/datamgmt/case/case_tasks_db.py +++ b/source/app/datamgmt/case/case_tasks_db.py @@ -22,6 +22,7 @@ from sqlalchemy import and_ from app import db +from app.datamgmt.conversions import convert_sort_direction from app.datamgmt.manage.manage_attribute_db import get_default_custom_attributes from app.datamgmt.manage.manage_users_db import get_users_list_restricted_from_case from app.datamgmt.states import update_tasks_state @@ -32,15 +33,36 @@ from app.models.models import TaskComments from app.models.models import TaskStatus from app.models.authorization import User +from app.models.pagination_parameters import PaginationParameters def get_tasks_status(): return TaskStatus.query.all() -def get_tasks(caseid): - return CaseTasks.query.with_entities( - CaseTasks.id.label("task_id"), +def get_filtered_tasks(case_identifier, pagination_parameters: PaginationParameters): + + query = CaseTasks.query.filter( + CaseTasks.task_case_id == case_identifier + ).join( + CaseTasks.status + ).order_by( + desc(TaskStatus.status_name) + ) + + sort_by = pagination_parameters.get_order_by() + if sort_by is not None: + order_func = convert_sort_direction(pagination_parameters.get_direction()) + + if hasattr(CaseTasks, sort_by): + query = query.order_by(order_func(getattr(CaseTasks, sort_by))) + + return query.paginate(page=pagination_parameters.get_page(), per_page=pagination_parameters.get_per_page(), error_out=False) + + +def get_tasks_with_assignees(caseid): + tasks = CaseTasks.query.with_entities( + CaseTasks.id.label('task_id'), CaseTasks.task_uuid, CaseTasks.task_title, CaseTasks.task_description, @@ -56,10 +78,6 @@ def get_tasks(caseid): ).order_by( desc(TaskStatus.status_name) ).all() - - -def get_tasks_with_assignees(caseid): - tasks = get_tasks(caseid) if not tasks: return None diff --git a/source/app/datamgmt/conversions.py b/source/app/datamgmt/conversions.py new file mode 100644 index 000000000..b09650ad1 --- /dev/null +++ b/source/app/datamgmt/conversions.py @@ -0,0 +1,27 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from sqlalchemy import asc +from sqlalchemy import desc + + +def convert_sort_direction(sort_direction): + if sort_direction == 'desc': + return desc + else: + return asc \ No newline at end of file diff --git a/source/app/datamgmt/manage/manage_assets_db.py b/source/app/datamgmt/manage/manage_assets_db.py index afdf7821d..a5575c855 100644 --- a/source/app/datamgmt/manage/manage_assets_db.py +++ b/source/app/datamgmt/manage/manage_assets_db.py @@ -1,9 +1,10 @@ from flask_login import current_user -from sqlalchemy import and_, desc, asc +from sqlalchemy import and_ from functools import reduce import app from app.datamgmt.manage.manage_cases_db import user_list_cases_view +from app.datamgmt.conversions import convert_sort_direction from app.models.cases import Cases from app.models.models import CaseAssets from app.models.models import Client @@ -57,7 +58,7 @@ def get_filtered_assets(case_id=None, data = data.join(CaseAssets.case).join(Cases.client) if sort_by is not None: - order_func = desc if sort_dir == 'desc' else asc + order_func = convert_sort_direction(sort_dir) if sort_by == 'name': data = data.order_by(order_func(CaseAssets.asset_name)) diff --git a/source/app/datamgmt/manage/manage_cases_db.py b/source/app/datamgmt/manage/manage_cases_db.py index 6b80fec52..650459e97 100644 --- a/source/app/datamgmt/manage/manage_cases_db.py +++ b/source/app/datamgmt/manage/manage_cases_db.py @@ -21,8 +21,6 @@ from pathlib import Path from sqlalchemy import and_ -from sqlalchemy import desc -from sqlalchemy import asc from sqlalchemy.orm import aliased from functools import reduce @@ -31,6 +29,7 @@ from app.datamgmt.alerts.alerts_db import search_alert_resolution_by_name from app.datamgmt.case.case_db import get_case_tags from app.datamgmt.manage.manage_case_state_db import get_case_state_by_name +from app.datamgmt.conversions import convert_sort_direction from app.datamgmt.authorization import has_deny_all_access_level from app.datamgmt.states import delete_case_states from app.models.models import CaseAssets @@ -67,6 +66,7 @@ from app.models.cases import CaseProtagonist from app.models.cases import CaseTags from app.models.cases import CaseState +from app.models.pagination_parameters import PaginationParameters def list_cases_id(): @@ -487,7 +487,7 @@ def build_filter_case_query(current_user_id, return query.join(Tags, Tags.tag_title.ilike(f'%{case_tags}%')).filter(CaseTags.case_id == Cases.case_id) if sort_by is not None: - order_func = desc if sort_dir == 'desc' else asc + order_func = convert_sort_direction(sort_dir) if sort_by == 'owner': query = query.join(User, Cases.owner_id == User.id).order_by(order_func(User.name)) @@ -507,6 +507,7 @@ def build_filter_case_query(current_user_id, def get_filtered_cases(current_user_id, + pagination_parameters: PaginationParameters, start_open_date: str = None, end_open_date: str = None, case_customer_id: int = None, @@ -520,23 +521,19 @@ def get_filtered_cases(current_user_id, case_state_id: int = None, case_soc_id: str = None, case_open_since: int = None, - per_page: int = None, - page: int = None, search_value=None, - sort_by=None, - sort_dir='asc', is_open: bool = None ): data = build_filter_case_query(case_classification_id=case_classification_id, case_customer_id=case_customer_id, case_description=case_description, case_ids=case_ids, case_name=case_name, case_opening_user_id=case_opening_user_id, case_owner_id=case_owner_id, case_severity_id=case_severity_id, case_soc_id=case_soc_id, case_open_since=case_open_since, case_state_id=case_state_id, current_user_id=current_user_id, end_open_date=end_open_date, - search_value=search_value, sort_by=sort_by, sort_dir=sort_dir, start_open_date=start_open_date, - is_open=is_open) + search_value=search_value, start_open_date=start_open_date, is_open=is_open, + sort_by=pagination_parameters.get_order_by(), sort_dir=pagination_parameters.get_direction()) try: - filtered_cases = data.paginate(page=page, per_page=per_page, error_out=False) + filtered_cases = data.paginate(page=pagination_parameters.get_page(), per_page=pagination_parameters.get_per_page(), error_out=False) except Exception as e: app.logger.exception(f'Error getting cases: {str(e)}') diff --git a/source/app/datamgmt/manage/manage_tags_db.py b/source/app/datamgmt/manage/manage_tags_db.py index 3e66b5bfd..0c9a7cd53 100644 --- a/source/app/datamgmt/manage/manage_tags_db.py +++ b/source/app/datamgmt/manage/manage_tags_db.py @@ -1,9 +1,10 @@ from functools import reduce -from sqlalchemy import and_, desc, asc +from sqlalchemy import and_ import app from app.models.models import Tags +from app.datamgmt.conversions import convert_sort_direction def get_filtered_tags(tag_title=None, @@ -37,7 +38,7 @@ def get_filtered_tags(tag_title=None, data = Tags.query.filter(*conditions) if sort_by is not None: - order_func = desc if sort_dir == 'desc' else asc + order_func = convert_sort_direction(sort_dir) if sort_by == 'name': data = data.order_by(order_func(Tags.tag_title)) diff --git a/source/app/datamgmt/manage/manage_users_db.py b/source/app/datamgmt/manage/manage_users_db.py index cb65d3183..0e12a87b1 100644 --- a/source/app/datamgmt/manage/manage_users_db.py +++ b/source/app/datamgmt/manage/manage_users_db.py @@ -15,17 +15,20 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + from typing import List from functools import reduce from flask_login import current_user -from sqlalchemy import and_, desc, asc +from sqlalchemy import and_ import app from app import bc from app import db from app.datamgmt.case.case_db import get_case -from app.iris_engine.access_control.utils import ac_access_level_mask_from_val_list, ac_ldp_group_removal +from app.datamgmt.conversions import convert_sort_direction +from app.iris_engine.access_control.utils import ac_access_level_mask_from_val_list +from app.iris_engine.access_control.utils import ac_ldp_group_removal from app.iris_engine.access_control.utils import ac_access_level_to_list from app.iris_engine.access_control.utils import ac_auto_update_user_effective_access from app.iris_engine.access_control.utils import ac_get_detailed_effective_permissions_from_groups @@ -750,7 +753,7 @@ def get_filtered_users(user_ids: str = None, if len(conditions) > 1: conditions = [reduce(and_, conditions)] - order_func = desc if sort == 'desc' else asc + order_func = convert_sort_direction(sort) try: diff --git a/source/app/models/pagination_parameters.py b/source/app/models/pagination_parameters.py new file mode 100644 index 000000000..e0937cb45 --- /dev/null +++ b/source/app/models/pagination_parameters.py @@ -0,0 +1,37 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +class PaginationParameters: + + def __init__(self, page, per_page, order_by, direction): + self._page = page + self._per_page = per_page + self._order_by = order_by + self._direction = direction + + def get_page(self): + return self._page + + def get_per_page(self): + return self._per_page + + def get_order_by(self): + return self._order_by + + def get_direction(self): + return self._direction diff --git a/tests/tests_rest_tasks.py b/tests/tests_rest_tasks.py index 2bc06e2b6..e894c2d6c 100644 --- a/tests/tests_rest_tasks.py +++ b/tests/tests_rest_tasks.py @@ -142,3 +142,52 @@ def test_update_task_should_return_a_task(self): response = self._subject.update(f'/api/v2/cases/{case_identifier}/tasks/{identifier}', {'task_title': 'new title', 'task_status_id': 1, 'task_assignees_id': []}).json() self.assertEqual('new title', response['task_title']) + + def test_get_tasks_should_return_200(self): + case_identifier = self._subject.create_dummy_case() + response = self._subject.get(f'/api/v2/cases/{case_identifier}/tasks') + self.assertEqual(200, response.status_code) + + def test_get_tasks_should_return_empty_list_for_field_data_when_there_are_no_tasks(self): + case_identifier = self._subject.create_dummy_case() + response = self._subject.get(f'/api/v2/cases/{case_identifier}/tasks').json() + self.assertEqual([], response['data']) + + def test_get_tasks_should_return_total(self): + case_identifier = self._subject.create_dummy_case() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'dummy title'} + self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + response = self._subject.get(f'/api/v2/cases/{case_identifier}/tasks').json() + self.assertEqual(1, response['total']) + + def test_get_tasks_should_honour_per_page_pagination_parameter(self): + case_identifier = self._subject.create_dummy_case() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'task1'} + self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'task2'} + self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'task3'} + self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + response = self._subject.get(f'/api/v2/cases/{case_identifier}/tasks', { 'per_page': 2 }).json() + self.assertEqual(2, len(response['data'])) + + def test_get_tasks_should_return_current_page(self): + case_identifier = self._subject.create_dummy_case() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'task1'} + self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'task2'} + self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'task3'} + self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + response = self._subject.get(f'/api/v2/cases/{case_identifier}/tasks', { 'page': 2, 'per_page': 2 }).json() + self.assertEqual(2, response['current_page']) + + def test_get_tasks_should_return_correct_task_uuid(self): + case_identifier = self._subject.create_dummy_case() + body = {'task_assignees_id': [], 'task_status_id': 1, 'task_title': 'title'} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/tasks', body).json() + identifier = response['id'] + response = self._subject.get(f'/api/v2/tasks/{identifier}').json() + expected_uuid = response['task_uuid'] + response = self._subject.get(f'/api/v2/cases/{case_identifier}/tasks').json() + self.assertEqual(expected_uuid, response['data'][0]['task_uuid'])