diff --git a/.dmypy.json b/.dmypy.json new file mode 100644 index 0000000..f9027c8 --- /dev/null +++ b/.dmypy.json @@ -0,0 +1 @@ +{"pid": 208431, "connection_name": "/tmp/tmpccrhv4cx/dmypy.sock"} diff --git a/.env.example b/.env.example index 2badb13..c3e5e9c 100644 --- a/.env.example +++ b/.env.example @@ -19,7 +19,7 @@ amcat4_elastic_verify_ssl=False amcat4_auth=no_auth # Middlecat server to trust as ID provider -amcat4_middlecat_url=https://middlecat.up.railway.app +amcat4_middlecat_url=https://middlecat.net # Email address for a hardcoded admin email (useful for setup and recovery) #amcat4_admin_email= diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 7bd5ca3..3985326 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -5,13 +5,12 @@ name: Flake8 on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false @@ -19,17 +18,17 @@ jobs: python-version: ["3.8", "3.9"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install -e .[dev] - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=env - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --max-line-length=127 --statistics --exclude=env + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install -e .[dev] + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=env + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . 
--count --max-line-length=127 --ignore=E203 --statistics --exclude=env diff --git a/.gitignore b/.gitignore index 82bb2ac..a3e54ee 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ # .env environment variables .env +.venv # C extensions *.so @@ -49,8 +50,10 @@ nosetests.xml # PyCharm meuk .idea -# vscode meuk -.vscode +# vscode meuk (only include extensions and settings) +.vscode/* +!.vscode/settings.json +!.vscode/extensions.json # static files navigator/media/static diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..21309ef --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,7 @@ +{ + "recommendations": [ + "matangover.mypy", + "ms-python.python", + "ms-python.black-formatter" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..412dd90 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,14 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true + }, + + "black-formatter.args": ["--line-length", "127"], + "mypy.enabled": true, + "mypy.runUsingActiveInterpreter": true, + "python.analysis.typeCheckingMode": "basic", + + "python.analysis.autoImportCompletions": true, + "flake8.args": ["--max-line-length=127", "--ignore=E203"] +} diff --git a/amcat4/__main__.py b/amcat4/__main__.py index 6873810..1ae98de 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -1,32 +1,32 @@ """ AmCAT4 REST API """ + import argparse -import collections import csv import io import json import logging import os +from pathlib import Path import secrets import sys +from typing import Any import urllib.request from enum import Enum -import elasticsearch +from collections import defaultdict +import elasticsearch.helpers import uvicorn from pydantic.fields import FieldInfo +from uvicorn.config import LOGGING_CONFIG + from amcat4 import index from amcat4.config import get_settings, AuthOptions, validate_settings -from amcat4.elastic import connect_elastic, get_system_version, ping, upload_documents -from amcat4.index import ( - GLOBAL_ROLES, - create_index, - set_global_role, - Role, - list_global_users, -) +from amcat4.elastic import connect_elastic, get_system_version, ping +from amcat4.index import GLOBAL_ROLES, create_index, set_global_role, Role, list_global_users, upload_documents +from amcat4.models import ElasticType, FieldType SOTU_INDEX = "state_of_the_union" @@ -51,20 +51,17 @@ def upload_test_data() -> str: ) for row in csvfile ] - columns = {"president": "keyword", "party": "keyword", "year": "double"} - upload_documents(SOTU_INDEX, docs, columns) + fields: dict[str, FieldType] = {"president": "keyword", "party": "keyword", "year": "integer"} + upload_documents(SOTU_INDEX, docs, fields) return SOTU_INDEX def run(args): auth = get_settings().auth - logging.info( - f"Starting server at port {args.port}, debug={not args.nodebug}, auth={auth}" - ) + logging.info(f"Starting server at port {args.port}, debug={not args.nodebug}, auth={auth}") if auth == AuthOptions.no_auth: logging.warning( - "Warning: No authentication is set up - " - "everyone who can access this service can view and change all data" + "Warning: No authentication is set up - " "everyone who can access this service can view and change all data" ) if validate_settings(): logging.warning(validate_settings()) @@ -75,9 +72,8 @@ def run(args): ) if ping(): logging.info(f"Connect to elasticsearch {get_settings().elastic_host}") - uvicorn.run( - "amcat4.api:app", host="0.0.0.0", 
reload=not args.nodebug, port=args.port - ) + log_config = "logging.yml" if Path("logging.yml").exists() else LOGGING_CONFIG + uvicorn.run("amcat4.api:app", host="0.0.0.0", reload=not args.nodebug, port=args.port, log_config=log_config) def val(val_or_list): @@ -88,31 +84,26 @@ def val(val_or_list): return val_or_list -def migrate_index(_args): +def migrate_index(_args) -> None: settings = get_settings() elastic = connect_elastic() if not elastic.ping(): logging.error(f"Cannot connect to elasticsearch server {settings.elastic_host}") sys.exit(1) if not elastic.indices.exists(index=settings.system_index): - logging.info( - "System index does not exist yet. It will be created automatically if you run the server" - ) + logging.info("System index does not exist yet. It will be created automatically if you run the server") sys.exit(1) # Check index format version version = get_system_version(elastic) - logging.info( - f"{settings.elastic_host}::{settings.system_index} is at version {version or 0}" - ) + logging.info(f"{settings.elastic_host}::{settings.system_index} is at version {version or 0}") if version == 1: logging.info("Nothing to do") else: logging.info("Migrating to version 1") fields = ["index", "email", "role"] - indices = collections.defaultdict(dict) - for entry in elasticsearch.helpers.scan( - elastic, index=settings.system_index, fields=fields, _source=False - ): + indices: defaultdict[str, dict[str, str]] = defaultdict(dict) + + for entry in elasticsearch.helpers.scan(elastic, index=settings.system_index, fields=fields, _source=False): index, email, role = [val(entry["fields"][field]) for field in fields] indices[index][email] = role if GLOBAL_ROLES not in indices: @@ -121,15 +112,11 @@ def migrate_index(_args): elastic.indices.delete(index=settings.system_index) for index, roles_dict in indices.items(): guest_role = roles_dict.pop("_guest", None) - roles_dict.pop("admin", None) - roles = [ - {"email": email, "role": role} - for (email, role) in roles_dict.items() - ] - doc = dict(name=index, guest_role=guest_role, roles=roles) + roles_dict.pop("ADMIN", None) + roles = [{"email": email, "role": role} for (email, role) in roles_dict.items()] + doc: dict[str, Any] = dict(name=index, guest_role=guest_role, roles=roles) if index == GLOBAL_ROLES: doc["version"] = 1 - print(doc) elastic.index(index=settings.system_index, id=index, document=doc) except Exception: try: @@ -148,7 +135,7 @@ def migrate_index(_args): def base_env(): return dict( amcat4_secret_key=secrets.token_hex(nbytes=32), - amcat4_middlecat_url="https://middlecat.up.netlify.app", + amcat4_middlecat_url="https://middlecat.net", ) @@ -160,10 +147,6 @@ def create_env(args): env = base_env() if args.admin_email: env["amcat4_admin_email"] = args.admin_email - if args.admin_password: - env["amcat4_admin_password"] = args.admin_password - if args.no_admin_password: - env["amcat4_admin_password"] = "" with open(".env", "w") as f: for key, val in env.items(): f.write(f"{key}={val}\n") @@ -183,45 +166,39 @@ def add_admin(args): def list_users(_args): - admin_password = get_settings().admin_password - if admin_password: - print("ADMIN : admin (password set via environment AMCAT4_ADMIN_PASSWORD)") - users = sorted(list_global_users(), key=lambda ur: (ur[1], ur[0])) + users = list_global_users() + + # sorted changes the output type of list_global_users? 
+ # users = sorted(list_global_users(), key=lambda ur: (ur[1], ur[0])) if users: - for user, role in users: + for user, role in users.items(): print(f"{role.name:10}: {user}") - if not (users or admin_password): - print( - "(No users defined yet, set AMCAT4_ADMIN_PASSWORD in environment use add-admin to add users by email)" - ) + if not users: + print("(No users defined yet, use add-admin to add users by email)") def config_amcat(args): settings = get_settings() - settings_dict = settings.model_dump() # Not a useful entry in an actual env_file - env_file_location = settings_dict.pop("env_file") - print(f"Reading/writing settings from {env_file_location}") - for fieldname in settings.model_fields_set: - if fieldname not in settings_dict: + print(f"Reading/writing settings from {settings.env_file}") + for fieldname, fieldinfo in settings.model_fields.items(): + if fieldname == "env_file": continue - fieldinfo = settings.model_fields[fieldname] + validation_function = AuthOptions.validate if fieldname == "auth" else None value = getattr(settings, fieldname) - value = menu( - fieldname, fieldinfo, value, validation_function=validation_function - ) + value = menu(fieldname, fieldinfo, value, validation_function=validation_function) if value is ABORTED: return if value is not UNCHANGED: - settings_dict[fieldname] = value + setattr(settings, fieldname, value) - with env_file_location.open("w") as f: - for fieldname, value in settings_dict.items(): - fieldinfo = settings.model_fields[fieldname] + with settings.env_file.open("w") as f: + for fieldname, fieldinfo in settings.model_fields.items(): + value = getattr(settings, fieldname) if doc := fieldinfo.description: f.write(f"# {doc}\n") - if _isenum(fieldinfo): + if _isenum(fieldinfo) and fieldinfo.annotation: f.write("# Valid options:\n") for option in fieldinfo.annotation: doc = option.__doc__.replace("\n", " ") @@ -231,7 +208,7 @@ def config_amcat(args): else: f.write(f"amcat4_{fieldname}={value}\n\n") os.chmod(".env", 0o600) - print(f"*** Written {bold('.env')} file to {env_file_location} ***") + print(f"*** Written {bold('.env')} file to {settings.env_file} ***") def bold(x): @@ -244,24 +221,23 @@ def bold(x): def _isenum(fieldinfo: FieldInfo) -> bool: try: - return issubclass(fieldinfo.annotation, Enum) + return issubclass(fieldinfo.annotation, Enum) if fieldinfo.annotation is not None else False except TypeError: return False def menu(fieldname: str, fieldinfo: FieldInfo, value, validation_function=None): print(f"\n{bold(fieldname)}: {fieldinfo.description}") - if _isenum(fieldinfo): + if _isenum(fieldinfo) and fieldinfo.annotation: print(" Possible choices:") - for option in fieldinfo.annotation: + options: Any = fieldinfo.annotation + for option in options: print(f" - {option.name}: {option.__doc__}") print() print(f"The current value for {bold(fieldname)} is {bold(value)}.") while True: try: - value = input( - "Enter a new value, press [enter] to leave unchanged, or press [control+c] to abort: " - ) + value = input("Enter a new value, press [enter] to leave unchanged, or press [control+c] to abort: ") except KeyboardInterrupt: return ABORTED if not value.strip(): @@ -275,9 +251,7 @@ def menu(fieldname: str, fieldinfo: FieldInfo, value, validation_function=None): def main(): parser = argparse.ArgumentParser(description=__doc__, prog="python -m amcat4") - subparsers = parser.add_subparsers( - dest="action", title="action", help="Action to perform:", required=True - ) + subparsers = parser.add_subparsers(dest="action", title="action", 
help="Action to perform:", required=True) p = subparsers.add_parser("run", help="Run the backend API in development mode") p.add_argument( "--no-debug", @@ -288,22 +262,11 @@ def main(): p.add_argument("-p", "--port", help="Port", default=5000) p.set_defaults(func=run) - p = subparsers.add_parser( - "create-env", help="Create the .env file with a random secret key" - ) + p = subparsers.add_parser("create-env", help="Create the .env file with a random secret key") p.add_argument("-a", "--admin_email", help="The email address of the admin user.") - p.add_argument( - "-p", "--admin_password", help="The password of the built-in admin user." - ) - p.add_argument( - "-P", "--no-admin_password", action="store_true", help="Disable admin password" - ) - p.set_defaults(func=create_env) - p = subparsers.add_parser( - "config", help="Configure amcat4 settings in an interactive menu." - ) + p = subparsers.add_parser("config", help="Configure amcat4 settings in an interactive menu.") p.set_defaults(func=config_amcat) p = subparsers.add_parser("add-admin", help="Add a global admin") @@ -313,24 +276,17 @@ def main(): p = subparsers.add_parser("list-users", help="List global users") p.set_defaults(func=list_users) - p = subparsers.add_parser( - "create-test-index", help=f"Create the {SOTU_INDEX} test index" - ) + p = subparsers.add_parser("create-test-index", help=f"Create the {SOTU_INDEX} test index") p.set_defaults(func=create_test_index) - p = subparsers.add_parser( - "migrate", help="Migrate the system index to the current version" - ) + p = subparsers.add_parser("migrate", help="Migrate the system index to the current version") p.set_defaults(func=migrate_index) args = parser.parse_args() - logging.basicConfig( - format="[%(levelname)-7s:%(name)-15s] %(message)s", level=logging.INFO - ) + logging.basicConfig(format="[%(levelname)-7s:%(name)-15s] %(message)s", level=logging.INFO) es_logger = logging.getLogger("elasticsearch") es_logger.setLevel(logging.WARNING) - args.func(args) diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index 4845f2c..782bdf0 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -1,12 +1,16 @@ """ Aggregate queries """ + +import copy from datetime import datetime -from typing import Mapping, Iterable, Union, Tuple, Sequence, List, Dict, Optional +from typing import Any, Mapping, Iterable, Union, Tuple, Sequence, List, Dict from amcat4.date_mappings import interval_mapping -from amcat4.elastic import es, get_fields -from amcat4.query import build_body, _normalize_queries +from amcat4.elastic import es +from amcat4.fields import get_fields +from amcat4.query import build_body +from amcat4.models import Field, FilterSpec def _combine_mappings(mappings): @@ -21,7 +25,8 @@ class Axis: """ Class that specifies an aggregation axis """ - def __init__(self, field: str, interval: str = None, name: str = None, field_type: str = None): + + def __init__(self, field: str, interval: str | None = None, name: str | None = None, field_type: str | None = None): self.field = field self.interval = interval self.ftype = field_type @@ -33,7 +38,7 @@ def __init__(self, field: str, interval: str = None, name: str = None, field_typ self.name = field def __repr__(self): - return f"" + return f"" def query(self): if not self.ftype: @@ -46,14 +51,14 @@ def query(self): else: return {self.name: {"histogram": {"field": self.field, "interval": self.interval}}} else: - return {self.name: {"terms": {"field": self.field}}} + return {self.name: {"terms": {"field": self.field, "order": "desc"}}} def 
get_value(self, values): value = values[self.name] if m := interval_mapping(self.interval): value = m.postprocess(value) elif self.ftype == "date": - value = datetime.utcfromtimestamp(value / 1000.) + value = datetime.utcfromtimestamp(value / 1000.0) if self.interval in {"year", "month", "week", "day"}: value = value.date() return value @@ -70,7 +75,8 @@ class Aggregation: """ Specification of a single aggregation, that is, field and aggregation function """ - def __init__(self, field: str, function: str, name: str = None, ftype: str = None): + + def __init__(self, field: str, function: str, name: str | None = None, ftype: str | None = None): self.field = field self.function = function self.name = name or f"{function}_{field}" @@ -80,9 +86,9 @@ def dsl_item(self): return self.name, {self.function: {"field": self.field}} def get_value(self, bucket: dict): - result = bucket[self.name]['value'] + result = bucket[self.name]["value"] if result and self.ftype == "date": - result = datetime.utcfromtimestamp(result / 1000.) + result = datetime.utcfromtimestamp(result / 1000.0) return result def asdict(self): @@ -95,95 +101,179 @@ def aggregation_dsl(aggregations: Iterable[Aggregation]) -> dict: class AggregateResult: - def __init__(self, axes: Sequence[Axis], aggregations: List[Aggregation], - data: List[tuple], count_column: str = "n"): + def __init__( + self, + axes: Sequence[Axis], + aggregations: List[Aggregation], + data: List[tuple], + count_column: str = "n", + after: dict | None = None, + ): self.axes = axes self.data = data self.aggregations = aggregations self.count_column = count_column + self.after = after def as_dicts(self) -> Iterable[dict]: """Return the results as a sequence of {axis1, ..., n} dicts""" - keys = tuple(ax.name for ax in self.axes) + (self.count_column, ) + keys = tuple(ax.name for ax in self.axes) + (self.count_column,) if self.aggregations: keys += tuple(a.name for a in self.aggregations) for row in self.data: yield dict(zip(keys, row)) -def _bare_aggregate(index: str, queries, filters, aggregations: Sequence[Aggregation]) -> Tuple[int, dict]: +def _bare_aggregate(index: str | list[str], queries, filters, aggregations: Sequence[Aggregation]) -> Tuple[int, dict]: """ Aggregate without sources/group_by. Returns a tuple of doc count and aggregegations (doc_count, {metric: value}) """ body = build_body(queries=queries, filters=filters) if filters or queries else {} + index = index if isinstance(index, str) else ",".join(index) aresult = es().search(index=index, size=0, aggregations=aggregation_dsl(aggregations), **body) cresult = es().count(index=index, **body) - return cresult['count'], aresult['aggregations'] + return cresult["count"], aresult["aggregations"] -def _elastic_aggregate(index: Union[str, List[str]], sources, queries, filters, aggregations: Sequence[Aggregation], - runtime_mappings: Mapping[str, Mapping] = None, after_key=None) -> Iterable[dict]: +def _elastic_aggregate( + index: str | list[str], + sources, + axes, + queries, + filters, + aggregations: list[Aggregation], + runtime_mappings: dict[str, Mapping] | None = None, + after_key=None, +) -> Tuple[list, dict | None]: """ Recursively get all buckets from a composite query. Yields 'buckets' consisting of {key: {axis: value}, doc_count: } """ # [WvA] Not sure if we should get all results ourselves or expose the 'after' pagination. # This might get us in trouble if someone e.g. 
aggregates on url or day for a large corpus - after = {"after": after_key} if after_key else {} + after = {"after": after_key} if after_key is not None and len(after_key) > 0 else {} aggr: Dict[str, Dict[str, dict]] = {"aggs": {"composite": dict(sources=sources, **after)}} if aggregations: - aggr["aggs"]['aggregations'] = aggregation_dsl(aggregations) + aggr["aggs"]["aggregations"] = aggregation_dsl(aggregations) kargs = {} + if filters or queries: - q = build_body(queries=queries.values(), filters=filters) + q = build_body(queries=queries, filters=filters) kargs["query"] = q["query"] - result = es().search(index=index if isinstance(index, str) else ",".join(index), - size=0, aggregations=aggr, runtime_mappings=runtime_mappings, **kargs - ) + + result = es().search( + index=index if isinstance(index, str) else ",".join(index), + size=0, + aggregations=aggr, + runtime_mappings=runtime_mappings, + **kargs, + ) if failure := result.get("_shards", {}).get("failures"): - raise Exception(f'Error on running aggregate search: {failure}') - yield from result['aggregations']['aggs']['buckets'] - after_key = result['aggregations']['aggs'].get('after_key') - if after_key: - yield from _elastic_aggregate(index, sources, queries, filters, aggregations, - runtime_mappings=runtime_mappings, after_key=after_key) + raise Exception(f"Error on running aggregate search: {failure}") + buckets = result["aggregations"]["aggs"]["buckets"] + after_key = result["aggregations"]["aggs"].get("after_key") -def _aggregate_results(index: Union[str, List[str]], axes: List[Axis], queries: Mapping[str, str], - filters: Optional[Mapping[str, Mapping]], aggregations: List[Aggregation]) -> Iterable[tuple]: - if not axes: + rows = [] + for bucket in buckets: + row = tuple(axis.get_value(bucket["key"]) for axis in axes) + row += (bucket["doc_count"],) + if aggregations: + row += tuple(a.get_value(bucket) for a in aggregations) + rows.append(row) + + return rows, after_key + + +def _aggregate_results( + index: Union[str, List[str]], + axes: List[Axis], + queries: dict[str, str] | None, + filters: dict[str, FilterSpec] | None, + aggregations: List[Aggregation], + after: dict[str, Any] | None = None, +): + + if not axes or len(axes) == 0: + # Path 1 # No axes, so return aggregations (or total count) only if aggregations: count, results = _bare_aggregate(index, queries, filters, aggregations) - yield (count,) + tuple(a.get_value(results) for a in aggregations) + rows = [(count,) + tuple(a.get_value(results) for a in aggregations)] else: - result = es().count(index=index if isinstance(index, str) else ",".join(index), - **build_body(queries=queries, filters=filters)) - yield result['count'], + result = es().count( + index=index if isinstance(index, str) else ",".join(index), **build_body(queries=queries, filters=filters) + ) + rows = [(result["count"],)] + yield rows, None + elif any(ax.field == "_query" for ax in axes): + + # Path 2 + # We cannot run the aggregation for multiple queries at once, so we loop over queries + # and recursively call _aggregate_results with one query at a time (which then uses path 3). 
+ if queries is None: + raise ValueError("Queries must be specified when aggregating by query") # Strip off _query axis and run separate aggregation for each query i = [ax.field for ax in axes].index("_query") - _axes = axes[:i] + axes[(i+1):] - for label, query in queries.items(): - for result_tuple in _aggregate_results(index, _axes, {label: query}, filters, aggregations): + _axes = axes[:i] + axes[(i + 1) :] + + query_items = list(queries.items()) + for label, query in query_items: + last_query = label == query_items[-1][0] + + if after is not None and "_query" in after: + # after is a dict with the aggregation values from which to continue + # pagination. Since we loop over queries, we add the _query value. + # Then after continuing from the right query, we remove this _query + # key so that the after dict is as elastic expects it + if after.get("_query") != label: + continue + after.pop("_query", None) + + for rows, after_buckets in _aggregate_results(index, _axes, {label: query}, filters, aggregations, after=after): + after_buckets = copy.deepcopy(after_buckets) + # insert label into the right position on the result tuple - yield result_tuple[:i] + (label,) + result_tuple[i:] + rows = [result_tuple[:i] + (label,) + result_tuple[i:] for result_tuple in rows] + + if after_buckets is None: + # if there are no buckets left for this query, we check if this is the last query. + # If not, we need to return the _query value to ensure pagination continues from this query + if not last_query: + after_buckets = {"_query": label} + else: + # if there are buckets left, we add the _query value to ensure pagination continues from this query + after_buckets["_query"] = label + yield rows, after_buckets + + # after only applies to the first query + after = None + else: - # Run an aggregation with one or more axes + # Path 3 + # Run an aggregation with one or more axes. If after is not None, we continue from there. sources = [axis.query() for axis in axes] runtime_mappings = _combine_mappings(axis.runtime_mappings() for axis in axes) - for bucket in _elastic_aggregate(index, sources, queries, filters, aggregations, runtime_mappings): - row = tuple(axis.get_value(bucket['key']) for axis in axes) - row += (bucket['doc_count'], ) - if aggregations: - row += tuple(a.get_value(bucket) for a in aggregations) - yield row + rows, after = _elastic_aggregate(index, sources, axes, queries, filters, aggregations, runtime_mappings, after) + yield rows, after -def query_aggregate(index: Union[str, List[str]], axes: Sequence[Axis] = None, aggregations: Sequence[Aggregation] = None, *, - queries: Union[Mapping[str, str], Sequence[str]] = None, - filters: Mapping[str, Mapping] = None) -> AggregateResult: + if after is not None: + for rows, after in _aggregate_results(index, axes, queries, filters, aggregations, after): + yield rows, after + + +def query_aggregate( + index: str | list[str], + axes: list[Axis] | None = None, + aggregations: list[Aggregation] | None = None, + *, + queries: dict[str, str] | None = None, + filters: dict[str, FilterSpec] | None = None, + after: dict[str, Any] | None = None, +) -> AggregateResult: """ Conduct an aggregate query. 
Note that interval queries also yield zero counts for intervening keys without value, @@ -199,15 +289,39 @@ def query_aggregate(index: Union[str, List[str]], axes: Sequence[Axis] = None, a """ if axes and len([x.field == "_query" for x in axes[1:]]) > 1: raise ValueError("Only one aggregation axis may be by query") - fields = get_fields(index) + + all_fields: dict[str, Field] = dict() + indices = index if isinstance(index, list) else [index] + for index in indices: + index_fields = get_fields(index) + for field_name, field in index_fields.items(): + if field_name not in all_fields: + all_fields[field_name] = field + else: + if field.type != all_fields[field_name].type: + raise ValueError(f"Type of {field_name} is not the same in all indices") + all_fields.update(get_fields(index)) + if not axes: axes = [] for axis in axes: - axis.ftype = "_query" if axis.field == "_query" else fields[axis.field]['type'] + axis.ftype = "_query" if axis.field == "_query" else all_fields[axis.field].type if not aggregations: aggregations = [] for aggregation in aggregations: - aggregation.ftype = fields[aggregation.field]['type'] - queries = _normalize_queries(queries) - data = list(_aggregate_results(index, axes, queries, filters, aggregations)) - return AggregateResult(axes, aggregations, data, count_column="n", ) + aggregation.ftype = all_fields[aggregation.field].type + + # We get the rows in sets of queries * buckets, and if there are queries or buckets left, + # the last_after value serves as a pagination cursor. Once we have > [stop_after] rows, + # we return the data and the last_after cursor. If the user needs to collect the rest, + # they need to paginate + stop_after = 1000 + gen = _aggregate_results(indices, axes, queries, filters, aggregations, after) + data = list() + last_after = None + for rows, after in gen: + data += rows + last_after = after + if len(data) > stop_after: + gen.close() + return AggregateResult(axes, aggregations, data, count_column="n", after=last_after) diff --git a/amcat4/api/__init__.py b/amcat4/api/__init__.py index 57b0427..b6f9c9f 100644 --- a/amcat4/api/__init__.py +++ b/amcat4/api/__init__.py @@ -1,12 +1,29 @@ """AmCAT4 API.""" -from fastapi import FastAPI +from contextlib import asynccontextmanager +import logging +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware from amcat4.api.index import app_index from amcat4.api.info import app_info from amcat4.api.query import app_query from amcat4.api.users import app_users +from amcat4.api.multimedia import app_multimedia +from amcat4.api.preprocessing import app_preprocessing +from amcat4.preprocessing.processor import start_processors + + +@asynccontextmanager +async def lifespan(app: FastAPI): + try: + start_processors() + except Exception: + logging.exception("Error on initializing preprocessing") + yield + app = FastAPI( title="AmCAT4", @@ -15,19 +32,25 @@ dict(name="users", description="Endpoints for user management"), dict(name="index", description="Endpoints to create, list, and delete indices; and to add or modify documents"), dict(name="query", description="Endpoints to list or query documents or run aggregate queries"), - dict(name='middlecat', description="MiddleCat authentication"), + dict(name="middlecat", description="MiddleCat authentication"), dict(name="annotator users", description="Annotator module endpoints for user management"), - dict(name="annotator codingjob", - 
description="Annotator module endpoints for creating and managing annotator codingjobs, " - "and the core process of getting units and posting annotations"), + dict( + name="annotator codingjob", + description="Annotator module endpoints for creating and managing annotator codingjobs, " + "and the core process of getting units and posting annotations", + ), dict(name="annotator guest", description="Annotator module endpoints for unregistered guests"), - ] - + dict(name="multimedia", description="Endpoints for multimedia support"), + dict(name="preprocessing", description="Endpoints for preprocessing support"), + ], + lifespan=lifespan, ) app.include_router(app_info) app.include_router(app_users) app.include_router(app_index) app.include_router(app_query) +app.include_router(app_multimedia) +app.include_router(app_preprocessing) app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -35,3 +58,12 @@ allow_methods=["*"], allow_headers=["*"], ) +app.add_middleware(GZipMiddleware, minimum_size=1000) + + +@app.exception_handler(ValueError) +async def value_error_exception_handler(request: Request, exc: ValueError): + return JSONResponse( + status_code=400, + content={"message": str(exc)}, + ) diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py index ae87683..94327e4 100644 --- a/amcat4/api/auth.py +++ b/amcat4/api/auth.py @@ -1,4 +1,6 @@ -"""Helper methods for authentication.""" +"""Helper methods for authentication and authorization.""" + +from argparse import ONE_OR_MORE import functools import logging from datetime import datetime @@ -6,13 +8,14 @@ import requests from authlib.common.errors import AuthlibBaseError from authlib.jose import jwt -from fastapi import HTTPException -from fastapi.params import Depends +from fastapi import HTTPException, Depends from fastapi.security import OAuth2PasswordBearer from starlette.status import HTTP_401_UNAUTHORIZED +from amcat4.models import FieldSpec from amcat4.config import get_settings, AuthOptions -from amcat4.index import Role, get_role, get_global_role +from amcat4.index import ADMIN_USER, GUEST_USER, Role, get_role, get_global_role +from amcat4.fields import get_fields oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) @@ -35,12 +38,12 @@ def verify_token(token: str) -> dict: raises a InvalidToken exception if the token could not be validated """ payload = decode_middlecat_token(token) - if missing := {'email', 'resource', 'exp'} - set(payload.keys()): + if missing := {"email", "resource", "exp"} - set(payload.keys()): raise InvalidToken(f"Invalid token, missing keys {missing}") now = int(datetime.now().timestamp()) - if payload['exp'] < now: + if payload["exp"] < now: raise InvalidToken("Token expired") - if payload['resource'] != get_settings().host: + if payload["resource"] != get_settings().host: raise InvalidToken(f"Wrong host! 
{payload['resource']} != {get_settings().host}") return payload @@ -52,7 +55,7 @@ def decode_middlecat_token(token: str) -> dict: url = get_settings().middlecat_url if not url: raise InvalidToken("No middlecat defined, cannot decrypt middlecat token") - public_key = get_middlecat_config(url)['public_key'] + public_key = get_middlecat_config(url)["public_key"] try: return jwt.decode(token, public_key) except AuthlibBaseError as e: @@ -76,8 +79,10 @@ def check_global_role(user: str, required_role: Role, raise_error=True): if global_role and global_role >= required_role: return global_role if raise_error: - raise HTTPException(status_code=401, detail=f"User {user} does not have global " - f"{required_role.name.title()} permissions on this instance") + raise HTTPException( + status_code=401, + detail=f"User {user} does not have global " f"{required_role.name.title()} permissions on this instance", + ) else: return False @@ -96,13 +101,80 @@ def check_role(user: str, required_role: Role, index: str, required_global_role: return get_role(index, user) # Global role check was false, so now check local role actual_role = get_role(index, user) + if get_settings().auth == AuthOptions.no_auth: return actual_role elif actual_role and actual_role >= required_role: return actual_role else: - raise HTTPException(status_code=401, detail=f"User {user} does not have " - f"{required_role.name.title()} permissions on index {index}") + raise HTTPException( + status_code=401, + detail=f"User {user} does not have " f"{required_role.name.title()} permissions on index {index}", + ) + + +def check_fields_access(index: str, user: str, fields: list[FieldSpec]) -> None: + """Check if the given user is allowed to query the given fields and snippets on the given index. + + :param index: The index to check the role on + :param user: The email address of the authenticated user + :param fields: The fields to check + :param snippets: The snippets to check + :return: Nothing. Throws HTTPException if the user is not allowed to query the given fields and snippets. + """ + + role = get_role(index, user) + if role is None: + raise HTTPException( + status_code=401, + detail=f"User {user} does not have a role on index {index}", + ) + if role >= Role.READER: + return None + if fields is None: + return None + + # after this, we know the user is a metareader, so we need to check metareader_access + index_fields = get_fields(index) + for field in fields: + if field.name not in index_fields: + # might be better to raise an error here, but since we want to support querying multiple + # indices at once, this allows the user to query fields that do not exist on all indices + continue + metareader = index_fields[field.name].metareader + + if metareader.access == "read": + continue + elif metareader.access == "snippet" and metareader.max_snippet is not None: + if metareader.max_snippet is None: + max_params_msg = "" + else: + max_params_msg = ( + "Can only read snippet with max parameters:" + f" nomatch_chars={metareader.max_snippet.nomatch_chars}" + f", max_matches={metareader.max_snippet.max_matches}" + f", match_chars={metareader.max_snippet.match_chars}" + ) + if field.snippet is None: + # if snippet is not specified, the whole field is requested + raise HTTPException( + status_code=401, detail=f"METAREADER cannot read {field} on index {index}. 
{max_params_msg}" + ) + + valid_nomatch_chars = field.snippet.nomatch_chars <= metareader.max_snippet.nomatch_chars + valid_max_matches = field.snippet.max_matches <= metareader.max_snippet.max_matches + valid_match_chars = field.snippet.match_chars <= metareader.max_snippet.match_chars + valid = valid_nomatch_chars and valid_max_matches and valid_match_chars + if not valid: + raise HTTPException( + status_code=401, + detail=f"The requested snippet of {field.name} on index {index} is too long. {max_params_msg}", + ) + else: + raise HTTPException( + status_code=401, + detail=f"METAREADER cannot read {field.name} on index {index}", + ) async def authenticated_user(token: str = Depends(oauth2_scheme)) -> str: @@ -110,21 +182,25 @@ async def authenticated_user(token: str = Depends(oauth2_scheme)) -> str: auth = get_settings().auth if token is None: if auth == AuthOptions.no_auth: - return "admin" + return ADMIN_USER elif auth == AuthOptions.allow_guests: - return "guest" + return GUEST_USER else: - raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, - detail="This instance has no guest access, please provide a valid bearer token") + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="This instance has no guest access, please provide a valid bearer token", + ) try: - user = verify_token(token)['email'] + user = verify_token(token)["email"] except Exception: logging.exception("Login failed") raise HTTPException(status_code=401, detail="Invalid token") if auth == AuthOptions.authorized_users_only: if get_global_role(user) is None: - raise HTTPException(status_code=401, - detail=f"The user {user} is not authorized to access this AmCAT instance") + raise HTTPException( + status_code=401, + detail=f"The user {user} is not authorized to access this AmCAT instance", + ) return user diff --git a/amcat4/api/index.py b/amcat4/api/index.py index fab1791..dd8b34b 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -1,34 +1,25 @@ """API Endpoints for document and index management.""" + from http import HTTPStatus -from typing import List, Literal, Mapping, Optional +from re import U +from typing import Annotated, Any, Literal import elasticsearch from elastic_transport import ApiError -from fastapi import APIRouter, HTTPException, Response, status -from fastapi.params import Body, Depends -from pydantic import BaseModel, ConfigDict +from fastapi import APIRouter, HTTPException, Response, status, Depends, Body +from pydantic import BaseModel +from datetime import datetime -from amcat4 import elastic, index +from amcat4 import index, fields as index_fields from amcat4.api.auth import authenticated_user, authenticated_writer, check_role -from amcat4.api.common import py2dict -from amcat4.index import ( - Index, - IndexDoesNotExist, - Role, - get_global_role, - get_index, - get_role, - list_known_indices, - list_users, -) -from amcat4.index import refresh_index as es_refresh_index + from amcat4.index import refresh_system_index, remove_role, set_role +from amcat4.fields import field_values, field_stats +from amcat4.models import CreateField, FieldType, UpdateField app_index = APIRouter(prefix="/index", tags=["index"]) -RoleType = Literal[ - "ADMIN", "WRITER", "READER", "METAREADER", "admin", "writer", "reader", "metareader" -] +RoleType = Literal["ADMIN", "WRITER", "READER", "METAREADER"] @app_index.get("/") @@ -39,48 +30,51 @@ def index_list(current_user: str = Depends(authenticated_user)): Returns a list of dicts containing name, role, and guest attributes """ - def 
index_to_dict(ix: Index) -> dict: - ix = ix._asdict() - ix["guest_role"] = ix["guest_role"] and ix["guest_role"].name - del ix["roles"] - return ix + def index_to_dict(ix: index.Index) -> dict: + ix_dict = ix._asdict() + guest_role_int = ix_dict.get("guest_role", 0) + + ix_dict = dict( + id=ix_dict["id"], + name=ix_dict["name"], + guest_role=index.Role(guest_role_int).name, + description=ix_dict.get("description", ""), + archived=ix_dict.get("archived", ""), + ) + return ix_dict - return [index_to_dict(ix) for ix in list_known_indices(current_user)] + return [index_to_dict(ix) for ix in index.list_known_indices(current_user)] class NewIndex(BaseModel): """Form to create a new index.""" id: str - guest_role: Optional[RoleType] = None - name: Optional[str] = None - description: Optional[str] = None + name: str | None = None + guest_role: RoleType | None = None + description: str | None = None @app_index.post("/", status_code=status.HTTP_201_CREATED) -def create_index( - new_index: NewIndex, current_user: str = Depends(authenticated_writer) -): +def create_index(new_index: NewIndex, current_user: str = Depends(authenticated_writer)): """ Create a new index, setting the current user to admin (owner). POST data should be json containing name and optional guest_role """ - guest_role = new_index.guest_role and Role[new_index.guest_role.upper()] + guest_role = new_index.guest_role and index.Role[new_index.guest_role.upper()] try: index.create_index( new_index.id, guest_role=guest_role, name=new_index.name, description=new_index.description, - admin=current_user, + admin=current_user if current_user != "_admin" else None, ) except ApiError as e: raise HTTPException( status_code=400, - detail=dict( - info=f"Error on creating index: {e}", message=e.message, body=e.body - ), + detail=dict(info=f"Error on creating index: {e}", message=e.message, body=e.body), ) @@ -88,23 +82,10 @@ def create_index( class ChangeIndex(BaseModel): """Form to update an existing index.""" - guest_role: Optional[ - Literal[ - "ADMIN", - "WRITER", - "READER", - "METAREADER", - "admin", - "writer", - "reader", - "metareader", - "NONE", - "none", - ] - ] = "None" - name: Optional[str] = None - description: Optional[str] = None - summary_field: Optional[str] = None + name: str | None = None + description: str | None = None + guest_role: Literal["WRITER", "READER", "METAREADER", "NONE"] | None = None + archive: bool | None = None @app_index.put("/{ix}") @@ -116,21 +97,23 @@ def modify_index(ix: str, data: ChangeIndex, user: str = Depends(authenticated_u User needs admin rights on the index """ - check_role(user, Role.ADMIN, ix) - guest_role, remove_guest_role = None, False - if data.guest_role: - role = data.guest_role.upper() - if role == "NONE": - remove_guest_role = True - else: - guest_role = Role[role] + check_role(user, index.Role.ADMIN, ix) + guest_role = index.GuestRole[data.guest_role] if data.guest_role is not None else None + archived = None + if data.archive is not None: + d = index.get_index(ix) + is_archived = d.archived is not None and d.archived != "" + if is_archived != data.archive: + archived = str(datetime.now()) if data.archive else "" + index.modify_index( ix, name=data.name, description=data.description, guest_role=guest_role, - remove_guest_role=remove_guest_role, - summary_field=data.summary_field, + archived=archived, + # remove_guest_role=remove_guest_role, + # unarchive=unarchive, ) refresh_system_index() @@ -141,65 +124,87 @@ def view_index(ix: str, user: str = Depends(authenticated_user)): View the 
index. """ try: - role = check_role(user, Role.METAREADER, ix, required_global_role=Role.WRITER) - d = get_index(ix)._asdict() - d["user_role"] = role and role.name - d["guest_role"] = d["guest_role"].name if d.get("guest_role") else None + role = check_role(user, index.Role.METAREADER, ix, required_global_role=index.Role.WRITER) + d = index.get_index(ix)._asdict() + d["user_role"] = role.name + d["guest_role"] = index.GuestRole(d.get("guest_role", 0)).name + d["description"] = d.get("description", "") or "" + d["name"] = d.get("name", "") or "" return d - except IndexDoesNotExist: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist" - ) + except index.IndexDoesNotExist: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") -@app_index.delete( - "/{ix}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response -) -def delete_index(ix: str, user: str = Depends(authenticated_user)): - """Delete the index.""" - check_role(user, Role.ADMIN, ix) - index.delete_index(ix) +@app_index.post("/{ix}/archive", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) +def archive_index( + ix: str, + archived: Annotated[bool, Body(description="Boolean for setting archived to true or false")], + user: str = Depends(authenticated_user), +): + """Archive or unarchive the index. When an index is archived, it restricts usage, and adds a timestamp for when + it was archived. An index can only be deleted if it has been archived for a specific amount of time.""" + check_role(user, index.Role.ADMIN, ix) + try: + d = index.get_index(ix) + is_archived = d.archived is not None + if is_archived == archived: + return + archived_date = str(datetime.now()) if archived else None + index.modify_index(ix, archived=archived_date) + except index.IndexDoesNotExist: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") -class Document(BaseModel): - """Form to create (upload) a new document.""" - title: str - date: str - text: str - url: Optional[str] = None - model_config = ConfigDict(extra="allow") +@app_index.delete("/{ix}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) +def delete_index(ix: str, user: str = Depends(authenticated_user)): + """Delete the index.""" + check_role(user, index.Role.ADMIN, ix) + try: + index.delete_index(ix) + except index.IndexDoesNotExist: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") @app_index.post("/{ix}/documents", status_code=status.HTTP_201_CREATED) def upload_documents( ix: str, - documents: List[Document] = Body(None, description="The documents to upload"), - columns: Optional[Mapping[str, str]] = Body( - None, description="Optional Specification of field (column) types" - ), + documents: Annotated[list[dict[str, Any]], Body(description="The documents to upload")], + fields: Annotated[ + dict[str, FieldType | CreateField] | None, + Body( + description="If a field in documents does not yet exist, you can create it on the spot. " + "If you only need to specify the type, and use the default settings, " + "you can use the short form: {field: type}" + ), + ] = None, + operation: Annotated[ + Literal["update", "create"], + Body( + description="The operation to perform. Default is create, which ignores any documents that already exist. " + "The 'update' operation behaves as an upsert (create or update). 
If an identical document (or document with " + "identical identifiers) already exists, the uploaded fields will be created or overwritten. If there are fields " + "in the original document that are not in the uploaded document, they will NOT be removed. since update is destructive " + "it requires admin rights." + ), + ] = "create", user: str = Depends(authenticated_user), ): """ - Upload documents to this server. - - JSON payload should contain a `documents` key, and may contain a `columns` key: - { - "documents": [{"title": .., "date": .., "text": .., ...}, ...], - "columns": {: , ...} - } - Returns a list of ids for the uploaded documents + Upload documents to this server. Returns a list of ids for the uploaded documents """ - check_role(user, Role.WRITER, ix) - documents = [py2dict(doc) for doc in documents] - return elastic.upload_documents(ix, documents, columns) + if operation == "create": + check_role(user, index.Role.WRITER, ix) + else: + check_role(user, index.Role.ADMIN, ix) + return index.upload_documents(ix, documents, fields, operation) @app_index.get("/{ix}/documents/{docid}") def get_document( ix: str, docid: str, - fields: Optional[str] = None, + fields: str | None = None, user: str = Depends(authenticated_user), ): """ @@ -208,12 +213,12 @@ def get_document( GET request parameters: fields - Comma separated list of fields to return (default: all fields) """ - check_role(user, Role.READER, ix) + check_role(user, index.Role.READER, ix) kargs = {} if fields: kargs["_source"] = fields try: - return elastic.get_document(ix, docid, **kargs) + return index.get_document(ix, docid, **kargs) except elasticsearch.exceptions.NotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -237,9 +242,9 @@ def update_document( PUT request body should be a json {field: value} mapping of fields to update """ - check_role(user, Role.WRITER, ix) + check_role(user, index.Role.WRITER, ix) try: - elastic.update_document(ix, docid, update) + index.update_document(ix, docid, update) except elasticsearch.exceptions.NotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -254,9 +259,9 @@ def update_document( ) def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user)): """Delete this document.""" - check_role(user, Role.WRITER, ix) + check_role(user, index.Role.WRITER, ix) try: - elastic.delete_document(ix, docid) + index.delete_document(ix, docid) except elasticsearch.exceptions.NotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -264,36 +269,77 @@ def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user) ) +@app_index.post("/{ix}/fields") +def create_fields( + ix: str, + fields: Annotated[ + dict[str, FieldType | CreateField], + Body( + description="Either a dictionary that maps field names to field specifications" + "({field: {type: 'text', identifier: True }}), " + "or a simplified version that only specifies the type ({field: type})" + ), + ], + user: str = Depends(authenticated_user), +): + """ + Create fields + """ + check_role(user, index.Role.WRITER, ix) + index_fields.create_fields(ix, fields) + return "", HTTPStatus.NO_CONTENT + + @app_index.get("/{ix}/fields") -def get_fields(ix: str, user=Depends(authenticated_user)): +def get_fields(ix: str, user: str = Depends(authenticated_user)): """ Get the fields (columns) used in this index. 
Returns a json array of {name, type} objects """ - check_role(user, Role.METAREADER, ix) - indices = ix.split(",") - return elastic.get_fields(indices) + check_role(user, index.Role.METAREADER, ix) + return index.get_fields(ix) -@app_index.post("/{ix}/fields") -def set_fields( - ix: str, body: dict = Body(...), user: str = Depends(authenticated_user) +@app_index.put("/{ix}/fields") +def update_fields( + ix: str, fields: Annotated[dict[str, UpdateField], Body(description="")], user: str = Depends(authenticated_user) ): """ - Set the field types used in this index. - - POST body should be a dict of {field: type} or {field: {type: type, meta: meta}} + Update the field settings """ - check_role(user, Role.WRITER, ix) - elastic.set_fields(ix, body) + check_role(user, index.Role.WRITER, ix) + + index_fields.update_fields(ix, fields) return "", HTTPStatus.NO_CONTENT @app_index.get("/{ix}/fields/{field}/values") -def get_values(ix: str, field: str, _=Depends(authenticated_user)): - """Get the fields (columns) used in this index.""" - return elastic.get_values(ix, field, size=100) +def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user)): + """ + Get unique values for a specific field. Should mainly/only be used for tag fields. + Main purpose is to provide a list of values for a dropdown menu. + + TODO: at the moment 'only' returns top 2000 values. Currently throws an + error if there are more than 2000 unique values. We can increase this limit, but + there should be a limit. Querying could be an option, but not sure if that is + efficient, since elastic has to aggregate all values first. + """ + check_role(user, index.Role.READER, ix) + values = field_values(ix, field, size=2001) + if len(values) > 2000: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Field {field} has more than 2000 unique values", + ) + return values + + +@app_index.get("/{ix}/fields/{field}/stats") +def get_field_stats(ix: str, field: str, user: str = Depends(authenticated_user)): + """Get statistics for a specific value. Only works for numeric (incl date) fields.""" + check_role(user, index.Role.READER, ix) + return field_stats(ix, field) @app_index.get("/{ix}/users") @@ -303,17 +349,17 @@ def list_index_users(ix: str, user: str = Depends(authenticated_user)): Allowed for global admin and local readers """ - if get_global_role(user) != Role.ADMIN: - check_role(user, Role.READER, ix) - return [{"email": u, "role": r.name} for (u, r) in list_users(ix).items()] + if index.get_global_role(user) != index.Role.ADMIN: + check_role(user, index.Role.READER, ix) + return [{"email": u, "role": r.name} for (u, r) in index.list_users(ix).items()] def _check_can_modify_user(ix, user, target_user, target_role): - if get_global_role(user) != Role.ADMIN: + if index.get_global_role(user) != index.Role.ADMIN: required_role = ( - Role.ADMIN - if (target_role == Role.ADMIN or get_role(ix, target_user) == Role.ADMIN) - else Role.WRITER + index.Role.ADMIN + if (target_role == index.Role.ADMIN or index.get_role(ix, target_user) == index.Role.ADMIN) + else index.Role.WRITER ) check_role(user, required_role, ix) @@ -331,7 +377,7 @@ def add_index_users( To create regular users you need WRITER permission. To create ADMIN users, you need ADMIN permission. Global ADMINs can always add users. 
""" - r = Role[role] + r = index.Role[role] _check_can_modify_user(ix, user, email, r) set_role(ix, email, r) return {"user": email, "index": ix, "role": r.name} @@ -350,7 +396,7 @@ def modify_index_user( This requires WRITER rights on the index or global ADMIN rights. If changing a user from or to ADMIN, it requires (local or global) ADMIN rights """ - r = Role[role] + r = index.Role[role] _check_can_modify_user(ix, user, email, r) set_role(ix, email, r) return {"user": email, "index": ix, "role": r.name} @@ -369,8 +415,6 @@ def remove_index_user(ix: str, email: str, user: str = Depends(authenticated_use return {"user": email, "index": ix, "role": None} -@app_index.get( - "/{ix}/refresh", status_code=status.HTTP_204_NO_CONTENT, response_class=Response -) +@app_index.get("/{ix}/refresh", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) def refresh_index(ix: str): - es_refresh_index(ix) + index.refresh_index(ix) diff --git a/amcat4/api/multimedia.py b/amcat4/api/multimedia.py new file mode 100644 index 0000000..0ffd239 --- /dev/null +++ b/amcat4/api/multimedia.py @@ -0,0 +1,69 @@ +import itertools +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException + +from amcat4 import index, multimedia +from amcat4.api.auth import authenticated_user, check_role +from minio.datatypes import Object +from minio.error import S3Error + +app_multimedia = APIRouter(prefix="/index/{ix}/multimedia", tags=["multimedia"]) + + +@app_multimedia.get("/presigned_get") +def presigned_get(ix: str, key: str, user: str = Depends(authenticated_user)): + check_role(user, index.Role.READER, ix) + try: + url = multimedia.presigned_get(ix, key) + obj = multimedia.stat_multimedia_object(ix, key) + return dict(url=url, content_type=(obj.content_type,), size=obj.size) + except S3Error as e: + if e.code == "NoSuchKey": + raise HTTPException(status_code=404, detail=f"multimedia file {key} not found") + raise HTTPException(status_code=404, detail=e.message) + + +@app_multimedia.get("/presigned_post") +def presigned_post(ix: str, user: str = Depends(authenticated_user)): + check_role(user, index.Role.WRITER, ix) + url, form_data = multimedia.presigned_post(ix) + return dict(url=url, form_data=form_data) + + +@app_multimedia.get("/list") +def list_multimedia( + ix: str, + n: int = 10, + prefix: Optional[str] = None, + start_after: Optional[str] = None, + recursive=False, + presigned_get=False, + metadata=False, + user: str = Depends(authenticated_user), +): + recursive = str(recursive).lower() == "true" + metadata = str(metadata).lower() == "true" + presigned_get = str(presigned_get).lower() == "true" + + def process(obj: Object): + if metadata and (not obj.is_dir) and obj.object_name: + obj = multimedia.stat_multimedia_object(ix, obj.object_name) + result: dict[str, object] = dict( + key=obj.object_name, + is_dir=obj.is_dir, + last_modified=obj.last_modified, + size=obj.size, + ) + if metadata: + result["metadata"] = (obj.metadata,) + result["content_type"] = (obj.content_type,) + + if presigned_get is True and not obj.is_dir: + if n > 10: + raise ValueError("Cannot provide presigned_get for more than 10 objects") + result["presigned_get"] = multimedia.presigned_get(ix, obj.object_name) + return result + + check_role(user, index.Role.READER, ix) + objects = multimedia.list_multimedia_objects(ix, prefix, start_after, recursive) + return [process(obj) for obj in itertools.islice(objects, n)] diff --git a/amcat4/api/preprocessing.py b/amcat4/api/preprocessing.py new file mode 100644 index 
0000000..0c3f2f1 --- /dev/null +++ b/amcat4/api/preprocessing.py @@ -0,0 +1,79 @@ +import asyncio +import logging +from typing import Annotated, Literal +from fastapi import APIRouter, Body, Depends, HTTPException, Response, status + +from amcat4 import index +from amcat4.api.auth import authenticated_user, check_role +from amcat4.preprocessing.models import PreprocessingInstruction +from amcat4.index import ( + get_instruction, + get_instructions, + add_instruction, + reassign_preprocessing_errors, + start_preprocessor, + stop_preprocessor, +) +from amcat4.preprocessing.processor import get_counts, get_manager +from amcat4.preprocessing.task import get_tasks + +logger = logging.getLogger("amcat4.preprocessing") + +app_preprocessing = APIRouter(tags=["preprocessing"]) + + +@app_preprocessing.get("/preprocessing_tasks") +def list_tasks(): + return [t.model_dump() for t in get_tasks()] + + +@app_preprocessing.get("/index/{ix}/preprocessing") +def list_instructions(ix: str, user: str = Depends(authenticated_user)): + check_role(user, index.Role.READER, ix) + return get_instructions(ix) + + +@app_preprocessing.post("/index/{ix}/preprocessing", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) +async def post_instruction(ix: str, instruction: PreprocessingInstruction, user: str = Depends(authenticated_user)): + check_role(user, index.Role.WRITER, ix) + add_instruction(ix, instruction) + + +@app_preprocessing.get("/index/{ix}/preprocessing/{field}") +async def get_instruction_details(ix: str, field: str, user: str = Depends(authenticated_user)): + check_role(user, index.Role.WRITER, ix) + i = get_instruction(ix, field) + if i is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Preprocessing instruction for field {field} on index {ix} not found", + ) + state = get_manager().get_status(ix, field) + counts = get_counts(ix, field) + return dict(instruction=i, status=state, counts=counts) + + +@app_preprocessing.get("/index/{ix}/preprocessing/{field}/status") +async def get_status(ix: str, field: str, user: str = Depends(authenticated_user)): + return dict(status=get_manager().get_status(ix, field)) + + +@app_preprocessing.post( + "/index/{ix}/preprocessing/{field}/status", status_code=status.HTTP_204_NO_CONTENT, response_class=Response +) +async def set_status( + ix: str, + field: str, + user: str = Depends(authenticated_user), + action: Literal["Start", "Stop", "Reassign"] = Body(description="Status to set for this preprocessing task", embed=True), +): + check_role(user, index.Role.WRITER, ix) + current_status = get_manager().get_status(ix, field) + if action == "Start" and current_status in {"Unknown", "Error", "Stopped", "Done"}: + start_preprocessor(ix, field) + elif action == "Stop" and current_status in {"Active"}: + stop_preprocessor(ix, field) + elif action == "Reassign": + reassign_preprocessing_errors(ix, field) + else: + raise HTTPException(422, f"Cannot {action}, (status: {current_status}; field {ix}.{field})") diff --git a/amcat4/api/query.py b/amcat4/api/query.py index a923b06..eec895f 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -1,15 +1,19 @@ """API Endpoints for querying.""" -from typing import Dict, List, Optional, Any, Union, Iterable, Tuple, Literal +from re import search +from typing import Annotated, Dict, List, Optional, Any, Union, Iterable, Literal -from fastapi import APIRouter, HTTPException, status, Request, Query, Depends, Response -from fastapi.params import Body +from fastapi import APIRouter, 
HTTPException, status, Depends, Response, Body +from pydantic import InstanceOf from pydantic.main import BaseModel from amcat4 import query, aggregate from amcat4.aggregate import Axis, Aggregation -from amcat4.api.auth import authenticated_user, check_role -from amcat4.index import Role +from amcat4.api.auth import authenticated_user, check_fields_access +from amcat4.config import AuthOptions, get_settings +from amcat4.fields import create_fields +from amcat4.index import Role, get_role, get_fields +from amcat4.models import FieldSpec, FilterSpec, FilterValue, SortSpec from amcat4.query import update_tag_query app_query = APIRouter(prefix="/index", tags=["query"]) @@ -32,234 +36,216 @@ class QueryResult(BaseModel): meta: QueryMeta -def _check_query_role( - indices: List[str], user: str, fields: List[str], highlight: bool -): - if (not fields) or ("text" in fields) or (highlight): - role = Role.READER - else: - role = Role.METAREADER - for ix in indices: - check_role(user, role, ix) - - -@app_query.get("/{index}/documents", response_model=QueryResult) -def get_documents( - index: str, - request: Request, - q: List[str] = Query( - None, - description="Elastic query string. " - "Argument may be repeated for multiple queries (treated as OR)", - ), - sort: str = Query( - None, - description="Comma separated list of fields to sort on", - examples="id,date:desc", - pattern=r"\w+(:desc)?(,\w+(:desc)?)*", - ), - fields: str = Query( - None, - description="Comma separated list of fields to return", - pattern=r"\w+(,\w+)*", - ), - per_page: int = Query(None, description="Number of results per page"), - page: int = Query(None, description="Page to fetch"), - scroll: str = Query( - None, - description="Create a new scroll_id to download all results in subsequent calls", - examples="3m", - ), - scroll_id: str = Query(None, description="Get the next batch from this scroll id"), - highlight: bool = Query(False, description="add highlight tags "), - annotations: bool = Query( - False, - description="if true, also return _annotations " - "with query matches as annotations", - ), - user: str = Depends(authenticated_user), -): +def get_or_validate_allowed_fields( + user: str, indices: Iterable[str], fields: list[FieldSpec] | None = None +) -> list[FieldSpec]: """ - List (possibly filtered) documents in this index. - - Any additional GET parameters are interpreted as filters, and can be - field=value for a term query, or field__xxx=value for a range query, with xxx in gte, gt, lte, lt - Note that dates can use relative queries, see elasticsearch 'date math' - In case of conflict between field names and (other) arguments, you may prepend a field name with __ - If your field names contain __, it might be better to use POST queries - - Returns a JSON object {data: [...], meta: {total_count, per_page, page_count, page|scroll_id}} + For any endpoint that returns field values, make sure the user only gets fields that + they are allowed to see. If fields is None, return all allowed fields. If fields is not None, + check whether the user can access the fields (If not, raise an error). 
""" - indices = index.split(",") - fields = fields and fields.split(",") - if not fields: - fields = ["date", "title", "url"] - _check_query_role(indices, user, fields, highlight) - args = {} - sort = sort and [ - {x.replace(":desc", ""): "desc"} if x.endswith(":desc") else x - for x in sort.split(",") - ] - known_args = ["page", "per_page", "scroll", "scroll_id", "highlight", "annotations"] - for name in known_args: - val = locals()[name] - if val: - args[name] = int(val) if name in ["page", "per_page"] else val - filters: Dict[str, Dict] = {} - for f, v in request.query_params.items(): - if f not in known_args + ["fields", "sort", "q"]: - if f.startswith("__"): - f = f[2:] - if "__" in f: # range query - (field, operator) = f.split("__") - if field not in filters: - filters[field] = {} - filters[field][operator] = v - else: # value query - if f not in filters: - filters[f] = {"values": []} - filters[f]["values"].append(v) - r = query.query_documents( - indices, fields=fields, queries=q, filters=filters, sort=sort, **args - ) - if r is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No results") - return r.as_dict() - - -FilterValue = Union[str, int] - - -class FilterSpec(BaseModel): - """Form for filter specification.""" - - values: Optional[List[FilterValue]] = None - gt: Optional[FilterValue] = None - lt: Optional[FilterValue] = None - gte: Optional[FilterValue] = None - lte: Optional[FilterValue] = None - exists: Optional[bool] = None - - -def _process_queries( - queries: Optional[Union[str, List[str], List[Dict[str, str]]]] = None -) -> Optional[dict]: + if not isinstance(user, str): + raise ValueError("User should be a string") + if not isinstance(indices, list): + raise ValueError("Indices should be a list") + if fields is not None and not isinstance(fields, list): + raise ValueError("Fields should be a list or None") + + no_auth = get_settings().auth == AuthOptions.no_auth + if fields is None: + if len(indices) > 1: + # this restrictions is needed, because otherwise we need to return all allowed fields taking + # into account the user's role for each index, and take the lowest possible access. + # this is error prone and complex, so best to just disallow it. 
Also, requesting all fields + # for multiple indices is probably not something we should support anyway + raise ValueError("Fields should be specified if multiple indices are given") + index_fields = get_fields(indices[0]) + role = get_role(indices[0], user) + allowed_fields: list[FieldSpec] = [] + for field in index_fields.keys(): + if role >= Role.READER or no_auth: + allowed_fields.append(FieldSpec(name=field)) + elif role == Role.METAREADER: + metareader = index_fields[field].metareader + if metareader.access == "read": + allowed_fields.append(FieldSpec(name=field)) + if metareader.access == "snippet": + allowed_fields.append(FieldSpec(name=field, snippet=metareader.max_snippet)) + else: + raise HTTPException( + status_code=401, + detail=f"User {user} does not have a role on index {indices[0]}", + ) + return allowed_fields + + for index in indices: + if not no_auth: + check_fields_access(index, user, fields) + return fields + + +def _standardize_queries(queries: str | list[str] | dict[str, str] | None = None) -> dict[str, str] | None: """Convert query json to dict format: {label1:query1, label2: query2} uses indices if no labels given.""" + if queries: # to dict format: {label1:query1, label2: query2} uses indices if no labels given if isinstance(queries, str): - queries = [queries] - if isinstance(queries, list): - queries = {str(i): q for i, q in enumerate(queries)} - return queries + return {"1": queries} + elif isinstance(queries, list): + return {str(i): q for i, q in enumerate(queries)} + elif isinstance(queries, dict): + return queries + return None -def _process_filters( - filters: Optional[ - Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] - ] = None -) -> Iterable[Tuple[str, dict]]: +def _standardize_filters( + filters: dict[str, FilterValue | list[FilterValue] | FilterSpec] | None = None +) -> dict[str, FilterSpec] | None: """Convert filters to dict format: {field: {values: []}}.""" if not filters: - return + return None + + f: dict[str, FilterSpec] = {} for field, filter_ in filters.items(): if isinstance(filter_, str): - filter_ = [filter_] - if isinstance(filter_, list): - yield field, {"values": filter_} + f[field] = FilterSpec(values=[filter_]) + elif isinstance(filter_, list): + f[field] = FilterSpec(values=filter_) elif isinstance(filter_, FilterSpec): - yield field, { - k: v for (k, v) in filter_.model_dump().items() if v is not None - } + f[field] = filter_ else: raise ValueError(f"Cannot parse filter: {filter_}") + return f + + +def _standardize_fieldspecs(fields: list[str | FieldSpec] | None = None) -> list[FieldSpec] | None: + """Convert fields to list of FieldSpecs.""" + if not fields: + return None + + f = [] + for field in fields: + if isinstance(field, str): + f.append(FieldSpec(name=field)) + elif isinstance(field, FieldSpec): + f.append(field) + else: + raise ValueError(f"Cannot parse field: {field}") + return f + + +def _standardize_sort(sort: str | list[str] | list[dict[str, SortSpec]] | None = None) -> list[dict[str, SortSpec]] | None: + """Convert sort to list of dicts.""" + + # TODO: sort cannot be right. 
that array around dict is useless + + if not sort: + return None + if isinstance(sort, str): + return [{sort: SortSpec(order="asc")}] + + sortspec: list[dict[str, SortSpec]] = [] + + for field in sort: + if isinstance(field, str): + sortspec.append({field: SortSpec(order="asc")}) + elif isinstance(field, dict): + sortspec.append(field) + else: + raise ValueError(f"Cannot parse sort: {sort}") + + return sortspec @app_query.post("/{index}/query", response_model=QueryResult) def query_documents_post( index: str, - queries: Optional[Union[str, List[str], Dict[str, str]]] = Body( - None, - description="Query/Queries to run. Value should be a single query string, a list of query strings, " - "or a dict of {'label': 'query'}", - ), - fields: Optional[List[str]] = Body( - None, description="List of fields to retrieve for each document" - ), - filters: Optional[ - Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] - ] = Body( - None, - description="Field filters, should be a dict of field names to filter specifications," - "which can be either a value, a list of values, or a FilterSpec dict", - ), - sort: Optional[Union[str, List[str], List[Dict[str, dict]]]] = Body( - None, - description="Sort by field name(s) or dict (see " - "https://www.elastic.co/guide/en/elasticsearch/reference/current/sort-search-results.html for dict format)", - examples={ - "simple": {"summary": "Sort by single field", "value": "'date'"}, - "multiple": { - "summary": "Sort by multiple fields", - "value": "['date', 'title']", + queries: Annotated[ + str | list[str] | dict[str, str] | None, + Body( + description="Query/Queries to run. Value should be a single query string, a list of query strings, " + "or a dict of {'label': 'query'}", + ), + ] = None, + fields: Annotated[ + list[str | FieldSpec] | None, + Body( + description="List of fields to retrieve for each document" + "In the list you can specify a fieldname, but also a FieldSpec dict." + "Using the FieldSpec allows you to request only a snippet of a field." + "fieldname[nomatch_chars;max_matches;match_chars]. 'matches' here refers to words from text queries. " + "If there is no query, the snippet is the first [nomatch_chars] characters. " + "If there is a query, snippets are returned for up to [max_matches] matches, with each match having [match_chars] " + "characters. If there are multiple matches, they are concatenated with ' ... 
'.", + openapi_examples={ + "simple": {"summary": "Retrieve single field", "value": '["title", "text", "date"]'}, + "text as snippet": { + "summary": "Retrieve the full title, but text only as snippet", + "value": '["title", {"name": "text", "snippet": {"nomatch_chars": 100}}]', + }, + "all allowed fields": { + "summary": "If fields is left empty, all fields that the user is allowed to see are returned", + }, }, - "dict": { - "summary": "Use dict to specify sort options", - "value": " [{'date': {'order':'desc'}}]", + ), + ] = None, + filters: Annotated[ + dict[str, FilterValue | list[FilterValue] | FilterSpec] | None, + Body( + description="Field filters, should be a dict of field names to filter specifications," + "which can be either a value, a list of values, or a FilterSpec dict", + ), + ] = None, + sort: Annotated[ + str | list[str] | list[dict[str, SortSpec]] | None, + Body( + description="Sort by field name(s) or dict (see " + "https://www.elastic.co/guide/en/elasticsearch/reference/current/sort-search-results.html for dict format)", + openapi_examples={ + "simple": {"summary": "Sort by single field", "value": "'date'"}, + "multiple": { + "summary": "Sort by multiple fields", + "value": "['date', 'title']", + }, + "dict": { + "summary": "Use dict to specify sort options", + "value": " [{'date': {'order':'desc'}}]", + }, }, - }, - ), - per_page: Optional[int] = Body(10, description="Number of documents per page"), - page: Optional[int] = Body(0, description="Which page to retrieve"), - scroll: Optional[str] = Body( - None, - description="Scroll specification (e.g. '5m') to start a scroll request" - "This will return a scroll_id which should be passed to subsequent calls" - "(this is the advised way of scrolling through multiple pages of results)", - examples="5m", - ), - scroll_id: Optional[str] = Body( - None, description="Scroll id from previous response to continue scrolling" - ), - annotations: Optional[bool] = Body( - None, description="Return _annotations with query matches as annotations" - ), - highlight: Optional[Union[bool, Dict]] = Body( - None, - description="Highlight document. 'true' highlights whole document, see elastic docs for dict format" - "https://www.elastic.co/guide/en/elasticsearch/reference/7.17/highlighting.html", - ), - user=Depends(authenticated_user), + ), + ] = None, + per_page: Annotated[int, Body(le=200, description="Number of documents per page")] = 10, + page: Annotated[int, Body(description="Which page to retrieve")] = 0, + scroll: Annotated[ + str | None, + Body( + description="Scroll specification (e.g. '5m') to start a scroll request" + "This will return a scroll_id which should be passed to subsequent calls" + "(this is the advised way of scrolling through multiple pages of results)", + examples=["5m"], + ), + ] = None, + scroll_id: Annotated[str | None, Body(description="Scroll id from previous response to continue scrolling")] = None, + highlight: Annotated[bool, Body(description="If true, highlight fields")] = False, + user: str = Depends(authenticated_user), ): """ List or query documents in this index. 
Returns a JSON object {data: [...], meta: {total_count, per_page, page_count, page|scroll_id}} """ - # TODO check user rights on index - # Standardize fields, queries and filters to their most versatile format indices = index.split(",") - if fields: - # to array format: fields: [field1, field2] - if isinstance(fields, str): - fields = [fields] - else: - fields = ["date", "title", "url"] - _check_query_role(indices, user, fields, highlight is not None) - - queries = _process_queries(queries) - filters = dict(_process_filters(filters)) + fieldspecs = get_or_validate_allowed_fields(user, indices, _standardize_fieldspecs(fields)) r = query.query_documents( indices, - queries=queries, - filters=filters, - fields=fields, - sort=sort, + queries=_standardize_queries(queries), + filters=_standardize_filters(filters), + fields=fieldspecs, + sort=_standardize_sort(sort), per_page=per_page, page=page, scroll_id=scroll_id, scroll=scroll, - annotations=annotations, highlight=highlight, ) if r is None: @@ -287,25 +273,24 @@ class AxisSpec(BaseModel): @app_query.post("/{index}/aggregate") def query_aggregate_post( index: str, - axes: Optional[List[AxisSpec]] = Body( - None, description="Axes to aggregate on (i.e. group by)" - ), - aggregations: Optional[List[AggregationSpec]] = Body( - None, description="Aggregate functions to compute" - ), - queries: Optional[Union[str, List[str], Dict[str, str]]] = Body( - None, - description="Query/Queries to run. Value should be a single query string, a list of query strings, " - "or a dict of queries {'label': 'query'}", - ), - filters: Optional[ - Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] - ] = Body( - None, - description="Field filters, should be a dict of field names to filter specifications," - "which can be either a value, a list of values, or a FilterSpec dict", - ), - _user=Depends(authenticated_user), + axes: Optional[List[AxisSpec]] = Body(None, description="Axes to aggregate on (i.e. group by)"), + aggregations: Optional[List[AggregationSpec]] = Body(None, description="Aggregate functions to compute"), + queries: Annotated[ + str | list[str] | dict[str, str] | None, + Body( + description="Query/Queries to run. Value should be a single query string, a list of query strings, " + "or a dict of {'label': 'query'}", + ), + ] = None, + filters: Annotated[ + dict[str, FilterValue | list[FilterValue] | FilterSpec] | None, + Body( + description="Field filters, should be a dict of field names to filter specifications," + "which can be either a value, a list of values, or a FilterSpec dict", + ), + ] = None, + after: Annotated[dict[str, Any] | None, Body(description="After cursor for pagination")] = None, + user: str = Depends(authenticated_user), ): """ Construct an aggregate query. 
@@ -321,64 +306,64 @@ def query_aggregate_post( # TODO check user rights on index indices = index.split(",") _axes = [Axis(**x.model_dump()) for x in axes] if axes else [] - _aggregations = ( - [Aggregation(**x.model_dump()) for x in aggregations] if aggregations else [] - ) - if not (_axes or _aggregations): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Aggregation needs at least one axis or aggregation", - ) - queries = _process_queries(queries) - filters = dict(_process_filters(filters)) + _aggregations = [Aggregation(**x.model_dump()) for x in aggregations] if aggregations else [] + results = aggregate.query_aggregate( - indices, _axes, _aggregations, queries=queries, filters=filters + indices, + _axes, + _aggregations, + queries=_standardize_queries(queries), + filters=_standardize_filters(filters), + after=after, ) + return { "meta": { "axes": [axis.asdict() for axis in results.axes], "aggregations": [a.asdict() for a in results.aggregations], + "after": results.after, }, "data": list(results.as_dicts()), } -@app_query.post( - "/{index}/tags_update", - status_code=status.HTTP_204_NO_CONTENT, - response_class=Response, -) +@app_query.post("/{index}/tags_update") def query_update_tags( index: str, - action: Literal["add", "remove"] = Body( - None, description="Action (add or remove) on tags" - ), + action: Literal["add", "remove"] = Body(None, description="Action (add or remove) on tags"), field: str = Body(None, description="Tag field to update"), tag: str = Body(None, description="Tag to add or remove"), - queries: Optional[Union[str, List[str], Dict[str, str]]] = Body( - None, - description="Query/Queries to run. Value should be a single query string, a list of query strings, " - "or a dict of {'label': 'query'}", - ), - filters: Optional[ - Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] - ] = Body( - None, - description="Field filters, should be a dict of field names to filter specifications," - "which can be either a value, a list of values, or a FilterSpec dict", - ), - ids: Optional[Union[str, List[str]]] = Body( - None, description="Document IDs of documents to update" - ), - _user=Depends(authenticated_user), + queries: Annotated[ + str | list[str] | dict[str, str] | None, + Body( + description="Query/Queries to run. 
Value should be a single query string, a list of query strings, " + "or a dict of {'label': 'query'}", + ), + ] = None, + filters: Annotated[ + dict[str, FilterValue | list[FilterValue] | FilterSpec] | None, + Body( + description="Field filters, should be a dict of field names to filter specifications," + "which can be either a value, a list of values, or a FilterSpec dict", + ), + ] = None, + ids: Optional[Union[str, List[str]]] = Body(None, description="Document IDs of documents to update"), + user: str = Depends(authenticated_user), ): """ Add or remove tags by query or by id """ indices = index.split(",") - queries = _process_queries(queries) - filters = dict(_process_filters(filters)) + for i in indices: + if get_role(i, user) < Role.WRITER: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"User {user} does not have permission to update tags on index {i}", + ) + if isinstance(ids, (str, int)): ids = [ids] - update_tag_query(indices, action, field, tag, queries, filters, ids) - return + update_result = update_tag_query( + indices, action, field, tag, _standardize_queries(queries), _standardize_filters(filters), ids + ) + return update_result diff --git a/amcat4/api/users.py b/amcat4/api/users.py index a5efd8a..7029a56 100644 --- a/amcat4/api/users.py +++ b/amcat4/api/users.py @@ -4,23 +4,23 @@ AmCAT4 can use either Basic or Token-based authentication. A client can request a token with basic authentication and store that token for future requests. """ + from typing import Literal, Optional from importlib.metadata import version -from fastapi import APIRouter, HTTPException, status, Response -from fastapi.params import Depends +from fastapi import APIRouter, HTTPException, status, Response, Depends from pydantic import BaseModel from pydantic.networks import EmailStr from amcat4 import index from amcat4.api.auth import authenticated_user, authenticated_admin, check_global_role from amcat4.config import get_settings, validate_settings -from amcat4.index import Role, set_global_role, get_global_role +from amcat4.index import ADMIN_USER, GUEST_USER, Role, set_global_role, get_global_role, user_exists app_users = APIRouter(tags=["users"]) -ROLE = Literal["ADMIN", "WRITER", "READER", "admin", "writer", "reader"] +ROLE = Literal["ADMIN", "WRITER", "READER", "NONE"] class UserForm(BaseModel): @@ -36,10 +36,10 @@ class ChangeUserForm(BaseModel): role: Optional[ROLE] = None -@app_users.post("/users/", status_code=status.HTTP_201_CREATED) +@app_users.post("/users", status_code=status.HTTP_201_CREATED) def create_user(new_user: UserForm, _=Depends(authenticated_admin)): """Create a new user.""" - if get_global_role(new_user.email) is not None: + if user_exists(new_user.email): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {new_user.email} already exists", @@ -74,7 +74,7 @@ def _get_user(email, current_user): if current_user != email: check_global_role(current_user, Role.WRITER) global_role = get_global_role(email) - if email in ("admin", "guest") or global_role is None: + if email in (ADMIN_USER, GUEST_USER) or global_role is Role.NONE: raise HTTPException(404, detail=f"User {email} unknown") else: return {"email": email, "role": global_role.name} @@ -83,15 +83,10 @@ def _get_user(email, current_user): @app_users.get("/users", dependencies=[Depends(authenticated_admin)]) def list_global_users(): """List all global users""" - return [ - {"email": email, "role": role.name} - for (email, role) in index.list_global_users().items() - ] + return 
[{"email": email, "role": role.name} for (email, role) in index.list_global_users().items()] -@app_users.delete( - "/users/{email}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response -) +@app_users.delete("/users/{email}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) def delete_user(email: EmailStr, current_user: str = Depends(authenticated_user)): """ Delete the given user. @@ -104,16 +99,18 @@ def delete_user(email: EmailStr, current_user: str = Depends(authenticated_user) @app_users.put("/users/{email}") -def modify_user( - email: EmailStr, data: ChangeUserForm, _user=Depends(authenticated_admin) -): +def modify_user(email: EmailStr, data: ChangeUserForm, _user: str = Depends(authenticated_admin)): """ Modify the given user. Only admin can change users. """ - role = Role[data.role.upper()] - set_global_role(email, role) - return {"email": email, "role": role.name} + if data.role is None or data.role == "NONE": + set_global_role(email, None) + return {"email": email, "role": None} + else: + role = Role[data.role.upper()] + set_global_role(email, role) + return {"email": email, "role": role.name} @app_users.get("/config") diff --git a/amcat4/config.py b/amcat4/config.py index 1c6506c..adb8285 100644 --- a/amcat4/config.py +++ b/amcat4/config.py @@ -6,15 +6,18 @@ - A .env file, either in the current working directory or in a location specified by the AMCAT4_CONFIG_FILE environment variable """ + import functools from enum import Enum from pathlib import Path -from typing import Optional +from typing import Annotated, Any from class_doc import extract_docs_from_cls_obj from dotenv import load_dotenv from pydantic import model_validator, Field from pydantic_settings import BaseSettings, SettingsConfigDict +ENV_PREFIX = "amcat4_" + class AuthOptions(str, Enum): #: everyone (that can reach the server) can do anything they want @@ -37,70 +40,85 @@ def validate(cls, value: str): return f"{value} is not a valid authorization option. Choose one of {{{options}}}" -# As far as I know, there is no elegant built-in way to set to __doc__ of an enum? +# Set the __doc__ attribute of each AuthOptions enum member using extract_docs_from_cls_obj for field, doc in extract_docs_from_cls_obj(AuthOptions).items(): AuthOptions[field].__doc__ = "\n".join(doc) class Settings(BaseSettings): - env_file: Path = Field( - ".env", - description="Location of a .env file (if used) relative to working directory", - ) - host: str = Field( - "http://localhost:5000", - description="Host this instance is served at (needed for checking tokens)", - ) - - elastic_password: Optional[str] = Field( - None, - description=( - "Elasticsearch password. " - "This the password for the 'elastic' user when Elastic xpack security is enabled" + env_file: Annotated[ + Path, + Field( + description="Location of a .env file (if used) relative to working directory", + ), + ] = Path(".env") + host: Annotated[ + str, + Field( + description="Host this instance is served at (needed for checking tokens)", ), - ) + ] = "http://localhost:5000" - elastic_host: Optional[str] = Field( - None, - description=( - "Elasticsearch host. " - "Default: https://localhost:9200 if elastic_password is set, http://localhost:9200 otherwise" + elastic_password: Annotated[ + str | None, + Field( + description=( + "Elasticsearch password. " "This the password for the 'elastic' user when Elastic xpack security is enabled" + ) ), - ) + ] = None + + elastic_host: Annotated[ + str | None, + Field( + description=( + "Elasticsearch host. 
" + "Default: https://localhost:9200 if elastic_password is set, http://localhost:9200 otherwise" + ) + ), + ] = None + + elastic_verify_ssl: Annotated[ + bool | None, + Field( + description=( + "Elasticsearch verify SSL (only used if elastic_password is set). " "Default: True unless host is localhost)" + ), + ), + ] = None - elastic_verify_ssl: Optional[bool] = Field( - None, - description=( - "Elasticsearch verify SSL (only used if elastic_password is set). " - "Default: True unless host is localhost)" + system_index: Annotated[ + str, + Field( + description="Elasticsearch index to store authorization information in", ), - ) + ] = "amcat4_system" - system_index: str = Field( - "amcat4_system", - description="Elasticsearch index to store authorization information in", - ) + auth: Annotated[AuthOptions, Field(description="Do we require authorization?")] = AuthOptions.no_auth - auth: AuthOptions = Field( - AuthOptions.no_auth, description="Do we require authorization?" - ) + middlecat_url: Annotated[ + str, + Field( + description="Middlecat server to trust as ID provider", + ), + ] = "https://middlecat.net" - middlecat_url: str = Field( - "https://middlecat.up.railway.app", - description="Middlecat server to trust as ID provider", - ) + admin_email: Annotated[ + str | None, + Field( + description="Email address for a hardcoded admin email (useful for setup and recovery)", + ), + ] = None - admin_email: Optional[str] = Field( - None, - description="Email address for a hardcoded admin email (useful for setup and recovery)", - ) + minio_host: Annotated[str | None, Field()] = None + minio_tls: Annotated[bool, Field()] = False + minio_access_key: Annotated[str | None, Field()] = None + minio_secret_key: Annotated[str | None, Field()] = None @model_validator(mode="after") - def set_ssl(self) -> "Settings": + def set_ssl(self: Any) -> "Settings": if not self.elastic_host: - self.elastic_host = ( - "https" if self.elastic_password else "http" - ) + "://localhost:9200" + self.elastic_host = ("https" if self.elastic_password else "http") + "://localhost:9200" if not self.elastic_verify_ssl: self.elastic_verify_ssl = self.elastic_host not in { "http://localhost:9200", @@ -108,7 +126,7 @@ def set_ssl(self) -> "Settings": } return self - model_config = SettingsConfigDict(env_prefix="amcat4_") + model_config = SettingsConfigDict(env_prefix=ENV_PREFIX) @functools.lru_cache() @@ -123,9 +141,7 @@ def get_settings() -> Settings: def validate_settings(): if get_settings().auth != "no_auth": - if get_settings().host.startswith( - "http://" - ) and not get_settings().host.startswith("http://localhost"): + if get_settings().host.startswith("http://") and not get_settings().host.startswith("http://localhost"): return ( "You have set the host at an http address and enabled authentication." 
"Authentication through middlecat will not work in your browser" @@ -135,5 +151,5 @@ def validate_settings(): if __name__ == "__main__": # Echo the settings - for k, v in get_settings().dict().items(): - print(f"{Settings.Config.env_prefix.upper()}{k.upper()}={v}") + for k, v in get_settings().model_dump().items(): + print(f"{ENV_PREFIX.upper()}{k.upper()}={v}") diff --git a/amcat4/date_mappings.py b/amcat4/date_mappings.py index ffbfc40..fcb3d6f 100644 --- a/amcat4/date_mappings.py +++ b/amcat4/date_mappings.py @@ -3,13 +3,10 @@ class DateMapping: - interval = None + interval: str | None = None def mapping(self, field: str) -> dict: - return {self.fieldname(field): { - "type": self.mapping_type(), - "script": self.mapping_script(field) - }} + return {self.fieldname(field): {"type": self.mapping_type(), "script": self.mapping_script(field)}} def mapping_script(self, field: str) -> str: raise NotImplementedError() @@ -96,10 +93,12 @@ def postprocess(self, value): return int(value) -def interval_mapping(interval: str) -> Optional[DateMapping]: - for m in mappings(): - if m.interval == interval: - return m +def interval_mapping(interval: str | None) -> Optional[DateMapping]: + if interval is not None: + for m in mappings(): + if m.interval == interval: + return m + return None def mappings() -> Iterable[DateMapping]: diff --git a/amcat4/elastic.py b/amcat4/elastic.py index 2a39f67..563d047 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -6,43 +6,19 @@ - The elasticsearch backend should contain a system index, which will be created if needed - The system index contains a 'document' for each used index containing: {auth: [{email: role}], guest_role: role} -- We define the mappings (field types) based on existing elasticsearch mappings, - but use field metadata to define specific fields, see ES_MAPPINGS below. 
+ """ + import functools -import hashlib -import json + import logging -from typing import Mapping, List, Iterable, Optional, Tuple, Union, Sequence, Literal +from typing import Optional from elasticsearch import Elasticsearch, NotFoundError -from elasticsearch.helpers import bulk - from amcat4.config import get_settings SYSTEM_INDEX_VERSION = 1 -ES_MAPPINGS = { - "long": {"type": "long"}, - "date": {"type": "date", "format": "strict_date_optional_time"}, - "double": {"type": "double"}, - "keyword": {"type": "keyword"}, - "url": {"type": "keyword", "meta": {"amcat4_type": "url"}}, - "tag": {"type": "keyword", "meta": {"amcat4_type": "tag"}}, - "id": {"type": "keyword", "meta": {"amcat4_type": "id"}}, - "text": {"type": "text"}, - "object": {"type": "object"}, - "geo_point": {"type": "geo_point"}, - "dense_vector_192": {"type": "dense_vector", "dims": 192}, -} - -DEFAULT_MAPPING = { - "text": ES_MAPPINGS["text"], - "title": ES_MAPPINGS["text"], - "date": ES_MAPPINGS["date"], - "url": ES_MAPPINGS["url"], -} - SYSTEM_MAPPING = { "name": {"type": "text"}, "description": {"type": "text"}, @@ -70,10 +46,16 @@ def connect_elastic() -> Elasticsearch: """ settings = get_settings() if settings.elastic_password: + host = settings.elastic_host + if settings.elastic_verify_ssl is None: + verify_certs = "localhost" in (host or "") + else: + verify_certs = settings.elastic_verify_ssl + return Elasticsearch( - settings.elastic_host or None, + host, basic_auth=("elastic", settings.elastic_password), - verify_certs=settings.elastic_verify_ssl, + verify_certs=verify_certs, ) else: return Elasticsearch(settings.elastic_host or None) @@ -130,212 +112,6 @@ def _setup_elastic(): return elastic -def coerce_type_to_elastic(value, ftype): - """ - Coerces values into the respective type in elastic - based on ES_MAPPINGS and elastic field types - """ - if ftype in ["keyword", "constant_keyword", "wildcard", "url", "tag", "text"]: - value = str(value) - elif ftype in [ - "long", - "short", - "byte", - "double", - "float", - "half_float", - "half_float", - "unsigned_long", - ]: - value = float(value) - elif ftype in ["integer"]: - value = int(value) - elif ftype == "boolean": - value = bool(value) - return value - - -def _get_hash(document: dict) -> bytes: - """ - Get the hash for a document - """ - hash_str = json.dumps(document, sort_keys=True, ensure_ascii=True, default=str).encode("ascii") - m = hashlib.sha224() - m.update(hash_str) - return m.hexdigest() - - -def upload_documents(index: str, documents, fields: Mapping[str, str] = None) -> None: - """ - Upload documents to this index - - :param index: The name of the index (without prefix) - :param documents: A sequence of article dictionaries - :param fields: A mapping of field:type for field types - """ - - def es_actions(index, documents): - field_types = get_index_fields(index) - for document in documents: - for key in document.keys(): - if key in field_types: - document[key] = coerce_type_to_elastic(document[key], field_types[key].get("type")) - if "_id" not in document: - document["_id"] = _get_hash(document) - yield {"_index": index, **document} - - if fields: - set_fields(index, fields) - - actions = list(es_actions(index, documents)) - bulk(es(), actions) - - -def get_field_mapping(type_: Union[str, dict]): - if isinstance(type_, str): - return ES_MAPPINGS[type_] - else: - mapping = ES_MAPPINGS[type_["type"]] - meta = mapping.get("meta", {}) - if m := type_.get("meta"): - meta.update(m) - mapping["meta"] = meta - return mapping - - -def set_fields(index: 
str, fields: Mapping[str, str]): - """ - Update the column types for this index - - :param index: The name of the index (without prefix) - :param fields: A mapping of field:type for column types - """ - properties = {field: get_field_mapping(type_) for (field, type_) in fields.items()} - es().indices.put_mapping(index=index, properties=properties) - - -def get_document(index: str, doc_id: str, **kargs) -> dict: - """ - Get a single document from this index. - - :param index: The name of the index - :param doc_id: The document id (hash) - :return: the source dict of the document - """ - return es().get(index=index, id=doc_id, **kargs)["_source"] - - -def update_document(index: str, doc_id: str, fields: dict): - """ - Update a single document. - - :param index: The name of the index - :param doc_id: The document id (hash) - :param fields: a {field: value} mapping of fields to update - """ - # Mypy doesn't understand that body= has been deprecated already... - es().update(index=index, id=doc_id, doc=fields) # type: ignore - - -def delete_document(index: str, doc_id: str): - """ - Delete a single document - - :param index: The name of the index - :param doc_id: The document id (hash) - """ - es().delete(index=index, id=doc_id) - - -def _get_type_from_property(properties: dict) -> str: - """ - Convert an elastic 'property' into an amcat4 field type - """ - result = properties.get("meta", {}).get("amcat4_type") - properties["type"] = properties.get("type", "object") - if result: - return result - return properties["type"] - - -def _get_fields(index: str) -> Iterable[Tuple[str, dict]]: - r = es().indices.get_mapping(index=index) - for k, v in r[index]["mappings"]["properties"].items(): - t = dict(name=k, type=_get_type_from_property(v)) - if meta := v.get("meta"): - t["meta"] = meta - yield k, t - - -def get_index_fields(index: str) -> Mapping[str, dict]: - """ - Get the field types in use in this index - :param index: - :return: a dict of fieldname: field objects {fieldname: {name, type, meta, ...}] - """ - return dict(_get_fields(index)) - - -def get_fields(index: Union[str, Sequence[str]]): - """ - Get the field types in use in this index or indices - :param index: name(s) of index(es) to query - :return: a dict of fieldname: field objects {fieldname: {name, type, ...}] - """ - if isinstance(index, str): - return get_index_fields(index) - result = {} - for ix in index: - for f, ftype in get_index_fields(ix).items(): - if f in result: - if result[f] != ftype: - result[f] = {"name": f, "type": "keyword", "meta": {"merged": True}} - else: - result[f] = ftype - return result - - -def get_values(index: str, field: str, size: int = 100) -> List[str]: - """ - Get the values for a given field (e.g. 
to populate list of filter values on keyword field) - :param index: The index - :param field: The field name - :return: A list of values - """ - aggs = {"values": {"terms": {"field": field}}} - r = es().search(index=index, size=size, aggs=aggs) - return [x["key"] for x in r["aggregations"]["values"]["buckets"]] - - -def update_by_query(index: str, script: str, query: dict, params: dict = None): - script = dict(source=script, lang="painless", params=params or {}) - es().update_by_query(index=index, script=script, **query) - - -TAG_SCRIPTS = dict( - add=""" - if (ctx._source[params.field] == null) { - ctx._source[params.field] = [params.tag] - } else if (!ctx._source[params.field].contains(params.tag)) { - ctx._source[params.field].add(params.tag) - } - """, - remove=""" - if (ctx._source[params.field] != null && ctx._source[params.field].contains(params.tag)) { - ctx._source[params.field].removeAll([params.tag]); - if (ctx._source[params.field].size() == 0) { - ctx._source.remove(params.field); - } - }""", -) - - -def update_tag_by_query(index: str, action: Literal["add", "remove"], query: dict, field: str, tag: str): - script = TAG_SCRIPTS[action] - params = dict(field=field, tag=tag) - update_by_query(index, script, query, params) - - def ping(): """ Can we reach this elasticsearch server diff --git a/amcat4/fields.py b/amcat4/fields.py new file mode 100644 index 0000000..8ed86cc --- /dev/null +++ b/amcat4/fields.py @@ -0,0 +1,382 @@ +""" +We have two types of fields: +- Elastic fields are the fields used under the hood by elastic. + (https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-types.html + These are stored in the Mapping of an index +- Amcat fields (Field) are the fields as seen by the amcat user. They use a simplified type, and contain additional + information such as metareader access + These are stored in the system index. + +We need to make sure that: +- When a user sets a field, it needs to be changed in both places: the system index and the mapping +- If a field only exists in the elastic mapping, we need to add the default Field to the system index. + This happens anytime get_fields is called, so that whenever a field is used it is guaranteed to be in the + system index +""" + +import datetime +import json +from typing import Any, Iterator, Literal, Mapping, get_args, cast + + +from elasticsearch import NotFoundError + +# from amcat4.api.common import py2dict +from amcat4.config import get_settings +from amcat4.elastic import es +from amcat4.models import FieldType, CreateField, ElasticType, Field, UpdateField, FieldMetareaderAccess + + +# given an elastic field type, check if it is supported by AmCAT. 
+# this is not just the inverse of TYPEMAP_AMCAT_TO_ES because some AmCAT types map to multiple elastic +# types (e.g., tag and keyword, image_url and wildcard) +# (this is relevant if we are importing an index) +TYPEMAP_ES_TO_AMCAT: dict[ElasticType, FieldType] = { + # TEXT fields + "text": "text", + "annotated_text": "text", + "binary": "text", + "match_only_text": "text", + # DATE fields + "date": "date", + # BOOLEAN fields + "boolean": "boolean", + # KEYWORD fields + "keyword": "keyword", + "constant_keyword": "keyword", + "wildcard": "keyword", + # INTEGER fields + "integer": "number", + "byte": "number", + "short": "number", + "long": "number", + "unsigned_long": "number", + # NUMBER fields + "float": "number", + "half_float": "number", + "double": "number", + "scaled_float": "number", + # OBJECT fields + "object": "object", + "flattened": "object", + "nested": "object", + # VECTOR fields (exclude sparse vectors) + "dense_vector": "vector", + # GEO fields + "geo_point": "geo_point", +} + +# maps amcat field types to elastic field types. +# The first elastic type in the array is the default. +TYPEMAP_AMCAT_TO_ES: dict[FieldType, list[ElasticType]] = { + "text": ["text", "annotated_text", "binary", "match_only_text"], + "date": ["date"], + "boolean": ["boolean"], + "keyword": ["keyword", "constant_keyword", "wildcard"], + "number": ["double", "float", "half_float", "scaled_float"], + "integer": ["long", "integer", "byte", "short", "unsigned_long"], + "object": ["object", "flattened", "nested"], + "vector": ["dense_vector"], + "geo_point": ["geo_point"], + "tag": ["keyword", "wildcard"], + "image": ["wildcard", "keyword", "constant_keyword", "text"], + "video": ["wildcard", "keyword", "constant_keyword", "text"], + "audio": ["wildcard", "keyword", "constant_keyword", "text"], + "url": ["wildcard", "keyword", "constant_keyword", "text"], + "json": ["text"], + "preprocess": ["object"], +} + + +def get_default_metareader(type: FieldType): + if type in ["boolean", "number", "date"]: + return FieldMetareaderAccess(access="read") + + return FieldMetareaderAccess(access="none") + + +def get_default_field(type: FieldType): + """ + Generate a field on the spot with default settings. + Primary use case is importing existing indices with fields that are not registered in the system index. 
+ """ + elastic_type = TYPEMAP_AMCAT_TO_ES.get(type) + if elastic_type is None: + raise ValueError( + f"The default elastic type mapping for field type {type} is not defined (if this happens, blame and inform Kasper)" + ) + return Field(elastic_type=elastic_type[0], type=type, metareader=get_default_metareader(type)) + + +def _standardize_createfields(fields: Mapping[str, FieldType | CreateField]) -> dict[str, CreateField]: + sfields = {} + for k, v in fields.items(): + if isinstance(v, str): + assert v in get_args(FieldType), f"Unknown amcat type {v}" + sfields[k] = CreateField(type=cast(FieldType, v)) + else: + sfields[k] = v + return sfields + + +def check_forbidden_type(field: Field, type: FieldType): + if field.identifier: + for forbidden_type in ["tag", "vector"]: + if type == forbidden_type: + raise ValueError(f"Field {field} is an identifier field, which cannot be a {forbidden_type} field") + + +def coerce_type(value: Any, type: FieldType): + """ + Coerces values into the respective type in elastic + based on ES_MAPPINGS and elastic field types + """ + if type == "date" and isinstance(value, datetime.date): + return value.isoformat() + if type in ["text", "tag", "image", "video", "audio", "date"]: + return str(value) + if type in ["boolean"]: + return bool(value) + if type in ["number"]: + return float(value) + if type in ["integer"]: + return int(value) + + if type == "json": + if isinstance(value, str): + return value + return json.dumps(value) + + # TODO: check coercion / validation for object, vector and geo types + if type in ["object"]: + return value + if type in ["vector"]: + return value + if type in ["geo_point"]: + return value + + return value + + +def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): + mapping: dict[str, Any] = {} + current_fields = get_fields(index) + + sfields = _standardize_createfields(fields) + old_identifiers = any(f.identifier for f in current_fields.values()) + new_identifiers = False + + for field, settings in sfields.items(): + if settings.elastic_type is not None: + allowed_types = TYPEMAP_AMCAT_TO_ES.get(settings.type, []) + if settings.elastic_type not in allowed_types: + raise ValueError( + f"Field type {settings.type} does not support elastic type {settings.elastic_type}. " + f"Allowed types are: {allowed_types}" + ) + elastic_type = settings.elastic_type + else: + elastic_type = get_default_field(settings.type).elastic_type + + current = current_fields.get(field) + if current is not None: + # fields can already exist. For example, a scraper might include the field types in every + # upload request. If a field already exists, we'll ignore it, but we will throw an error + # if static settings (elastic type, identifier) do not match. + if current.elastic_type != elastic_type: + raise ValueError(f"Field '{field}' already exists with elastic type '{current.elastic_type}'. ") + if current.identifier != settings.identifier: + raise ValueError(f"Field '{field}' already exists with identifier '{current.identifier}'. 
") + continue + + # if field does not exist, we add it to both the mapping and the system index + if settings.identifier: + new_identifiers = True + mapping[field] = {"type": elastic_type} + if settings.type == "preprocess": + mapping[field]["properties"] = dict(status=dict(type="keyword"), error=dict(type="text", index=False)) + if settings.type in ["date"]: + mapping[field]["format"] = "strict_date_optional_time" + + current_fields[field] = Field( + type=settings.type, + elastic_type=elastic_type, + identifier=settings.identifier, + metareader=settings.metareader or get_default_metareader(settings.type), + client_settings=settings.client_settings or {}, + ) + check_forbidden_type(current_fields[field], settings.type) + + if new_identifiers: + # new identifiers are only allowed if the index had identifiers, or if it is a new index (i.e. no documents) + has_docs = es().count(index=index)["count"] > 0 + if has_docs and not old_identifiers: + raise ValueError("Cannot add identifiers. Index already has documents with no identifiers.") + + if len(mapping) > 0: + # if there are new identifiers, check whether this is allowed first + es().indices.put_mapping(index=index, properties=mapping) + es().update( + index=get_settings().system_index, + id=index, + doc=dict(fields=_fields_to_elastic(current_fields)), + ) + + +def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: + # some additional validation + return [{"field": field, "settings": settings.model_dump()} for field, settings in fields.items()] + + +def _fields_from_elastic( + fields: list[dict], +) -> dict[str, Field]: + return {fs["field"]: Field.model_validate(fs["settings"]) for fs in fields} + + +def update_fields(index: str, fields: dict[str, UpdateField]): + """ + Set the fields settings for this index. Only updates fields that + already exist. Only keys in UpdateField can be updated (not type or client_settings) + """ + + current_fields = get_fields(index) + + for field, new_settings in fields.items(): + current = current_fields.get(field) + if current is None: + raise ValueError(f"Field {field} does not exist") + + if new_settings.type is not None: + check_forbidden_type(current, new_settings.type) + + valid_es_types = TYPEMAP_AMCAT_TO_ES.get(new_settings.type) + if valid_es_types is None: + raise ValueError(f"Invalid field type: {new_settings.type}") + if current.elastic_type not in valid_es_types: + raise ValueError( + f"Field {field} has the elastic type {current.elastic_type}. A {new_settings.type} field can only have the following elastic types: {valid_es_types}." + ) + current_fields[field].type = new_settings.type + + if new_settings.metareader is not None: + if current.type != "text" and new_settings.metareader.access == "snippet": + raise ValueError(f"Field {field} is not of type text, cannot set metareader access to snippet") + current_fields[field].metareader = new_settings.metareader + + if new_settings.client_settings is not None: + current_fields[field].client_settings = new_settings.client_settings + + es().update( + index=get_settings().system_index, + id=index, + doc=dict(fields=_fields_to_elastic(current_fields)), + ) + + +def _get_index_fields(index: str) -> Iterator[tuple[str, ElasticType]]: + r = es().indices.get_mapping(index=index) + + if len(r[index]["mappings"]) > 0: + for k, v in r[index]["mappings"]["properties"].items(): + yield k, v.get("type", "object") + + +def get_fields(index: str) -> dict[str, Field]: + """ + Retrieve the fields settings for this index. 
Look for both the field settings in the system index, + and the field mappings in the index itself. If a field is not defined in the system index, return the + default settings for that field type and add it to the system index. This way, any elastic index can be imported + """ + fields: dict[str, Field] = {} + system_index = get_settings().system_index + + try: + d = es().get( + index=system_index, + id=index, + source_includes="fields", + ) + system_index_fields = _fields_from_elastic(d["_source"].get("fields", {})) + except NotFoundError: + system_index_fields = {} + + update_system_index = False + for field, elastic_type in _get_index_fields(index): + type = TYPEMAP_ES_TO_AMCAT.get(elastic_type) + + if type is None: + # skip over unsupported elastic fields. + # (TODO: also return warning to client?) + continue + + if field not in system_index_fields: + update_system_index = True + fields[field] = get_default_field(type) + else: + fields[field] = system_index_fields[field] + + if update_system_index: + es().update( + index=system_index, + id=index, + doc=dict(fields=_fields_to_elastic(fields)), + ) + + return fields + + +def create_or_verify_tag_field(index: str | list[str], field: str): + """Create a special type of field that can be used to tag documents. + Since adding/removing tags supports multiple indices, we first check whether the field name is valid for all indices.""" + indices = [index] if isinstance(index, str) else index + add_to_indices = [] + for i in indices: + current_fields = get_fields(i) + if field in current_fields: + if current_fields[field].type != "tag": + raise ValueError(f"Field '{field}' already exists in index '{i}' and is not a tag field") + + else: + add_to_indices.append(i) + + for i in add_to_indices: + current_fields = get_fields(i) + current_fields[field] = get_default_field("tag") + es().indices.put_mapping(index=i, properties={field: {"type": "keyword"}}) + es().update( + index=get_settings().system_index, + id=i, + doc=dict(fields=_fields_to_elastic(current_fields)), + ) + + +def field_values(index: str, field: str, size: int) -> list[str]: + """ + Get the values for a given field (e.g. 
to populate list of filter values on keyword field) + Results are sorted descending by document frequency + see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html + #search-aggregations-bucket-terms-aggregation-order + + :param index: The index + :param field: The field name + :return: A list of values + """ + aggs = {"unique_values": {"terms": {"field": field, "size": size}}} + r = es().search(index=index, size=0, aggs=aggs) + return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]] + + +def field_stats(index: str, field: str) -> list[str]: + """ + :param index: The index + :param field: The field name + :return: A list of values + """ + aggs = {"facets": {"stats": {"field": field}}} + r = es().search(index=index, size=0, aggs=aggs) + return r["aggregations"]["facets"] + + +def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None = None): + script_dict = dict(source=script, lang="painless", params=params or {}) + es().update_by_query(index=index, script=script_dict, **query) diff --git a/amcat4/index.py b/amcat4/index.py index 7477b9b..1aa4690 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -28,30 +28,60 @@ - This system index contains a 'document' for each index: {name: "...", description:"...", guest_role: "...", roles: [{email, role}...]} - A special _global document defines the global properties for this instance (name, roles) +- We define the mappings (field types) based on existing elasticsearch mappings, + but use field metadata to define specific fields. """ + import collections +from curses import meta +from dataclasses import field from enum import IntEnum -from typing import Dict, Iterable, List, Optional +import functools +import logging +from typing import Any, Iterable, Mapping, Optional, Literal + +import hashlib +import json import elasticsearch.helpers from elasticsearch import NotFoundError -from amcat4.config import get_settings -from amcat4.elastic import DEFAULT_MAPPING, es, get_fields +# from amcat4.api.common import py2dict +from amcat4.config import AuthOptions, get_settings +from amcat4.elastic import es +from amcat4.fields import ( + coerce_type, + create_fields, + create_or_verify_tag_field, + get_fields, +) +from amcat4.models import CreateField, Field, FieldType +from amcat4.preprocessing.models import PreprocessingInstruction +from amcat4.preprocessing import processor class Role(IntEnum): + NONE = 0 METAREADER = 10 READER = 20 WRITER = 30 ADMIN = 40 +class GuestRole(IntEnum): + NONE = 0 + METAREADER = 10 + READER = 20 + WRITER = 30 + + +ADMIN_USER = "_admin" GUEST_USER = "_guest" GLOBAL_ROLES = "_global" Index = collections.namedtuple( - "Index", ["id", "name", "description", "guest_role", "roles", "summary_field"] + "Index", + ["id", "name", "description", "guest_role", "archived", "roles", "summary_field"], ) @@ -73,7 +103,7 @@ def refresh_system_index(): es().indices.refresh(index=get_settings().system_index) -def list_known_indices(email: str = None) -> Iterable[Index]: +def list_known_indices(email: str | None = None) -> Iterable[Index]: """ List all known indices, e.g. 
indices registered in this amcat4 instance :param email: if given, only list indices visible to this user @@ -84,14 +114,8 @@ def list_known_indices(email: str = None) -> Iterable[Index]: # "must_not": {"term": {"guest_role": {"value": "none", "case_insensitive": True}}}}} # q_role = {"nested": {"path": "roles", "query": {"term": {"roles.email": email}}}} # query = {"bool": {"should": [q_guest, q_role]}} - check_role = not ( - email is None - or get_global_role(email) == Role.ADMIN - or get_settings().auth == "no_auth" - ) - for index in elasticsearch.helpers.scan( - es(), index=get_settings().system_index, fields=[], _source=True - ): + check_role = not (email is None or get_global_role(email) == Role.ADMIN or get_settings().auth == "no_auth") + for index in elasticsearch.helpers.scan(es(), index=get_settings().system_index, fields=[], _source=True): ix = _index_from_elastic(index) if ix.name == GLOBAL_ROLES: continue @@ -101,12 +125,14 @@ def list_known_indices(email: str = None) -> Iterable[Index]: def _index_from_elastic(index): src = index["_source"] - guest_role = src.get("guest_role") + guest_role = src.get("guest_role", "NONE") + return Index( id=index["_id"], name=src.get("name", index["_id"]), description=src.get("description"), - guest_role=guest_role and guest_role != "NONE" and Role[guest_role.upper()], + guest_role=Role[guest_role] if guest_role in Role.__members__ else Role.NONE, + archived=src.get("archived"), roles=_roles_from_elastic(src.get("roles", [])), summary_field=src.get("summary_field"), ) @@ -114,10 +140,10 @@ def _index_from_elastic(index): def get_index(index: str) -> Index: try: - index = es().get(index=get_settings().system_index, id=index) + index_resp = es().get(index=get_settings().system_index, id=index) except NotFoundError: raise IndexDoesNotExist(index) - return _index_from_elastic(index) + return _index_from_elastic(index_resp) def create_index( @@ -130,9 +156,19 @@ def create_index( """ Create a new index in elasticsearch and register it with this AmCAT instance """ - es().indices.create(index=index, mappings={"properties": DEFAULT_MAPPING}) + try: + get_index(index) + raise ValueError(f'Index "{index}" already exists') + except IndexDoesNotExist: + pass + + es().indices.create(index=index, mappings={"properties": {}}) register_index( - index, guest_role=guest_role, name=name, description=description, admin=admin + index, + guest_role=guest_role or Role.NONE, + name=name or index, + description=description or "", + admin=admin, ) @@ -152,6 +188,7 @@ def register_index( if es().exists(index=system_index, id=index): raise ValueError(f"Index {index} is already registered") roles = [dict(email=admin, role="ADMIN")] if admin else [] + es().index( index=system_index, id=index, @@ -159,7 +196,7 @@ def register_index( name=(name or index), roles=roles, description=description, - guest_role=guest_role and guest_role.name, + guest_role=guest_role.name if guest_role is not None else "NONE", ), ) refresh_index(system_index) @@ -171,9 +208,9 @@ def delete_index(index: str, ignore_missing=False) -> None: :param index: The name of the index :param ignore_missing: If True, do not throw exception if index does not exist """ - deregister_index(index, ignore_missing=ignore_missing) _es = es().options(ignore_status=404) if ignore_missing else es() _es.indices.delete(index=index) + deregister_index(index, ignore_missing=ignore_missing) def deregister_index(index: str, ignore_missing=False) -> None: @@ -190,13 +227,17 @@ def deregister_index(index: str, 
ignore_missing=False) -> None: raise else: refresh_index(system_index) + # Stop preprocessing loops on this index + from amcat4.preprocessing.processor import get_manager + + get_manager().remove_index_preprocessors(index) -def _roles_from_elastic(roles: List[Dict]) -> Dict[str, Role]: +def _roles_from_elastic(roles: list[dict]) -> dict[str, Role]: return {role["email"]: Role[role["role"].upper()] for role in roles} -def _roles_to_elastic(roles: dict) -> List[Dict]: +def _roles_to_elastic(roles: dict) -> list[dict]: return [{"email": email, "role": role.name} for (email, role) in roles.items()] @@ -210,7 +251,7 @@ def set_role(index: str, email: str, role: Optional[Role]): try: d = es().get(index=system_index, id=index, source_includes="roles") except NotFoundError: - raise ValueError(f"Index {index} does is not registered") + raise ValueError(f"Index {index} is not registered") roles_dict = _roles_from_elastic(d["_source"].get("roles", [])) if role: roles_dict[email] = role @@ -218,50 +259,43 @@ def set_role(index: str, email: str, role: Optional[Role]): if email not in roles_dict: return # Nothing to change del roles_dict[email] + es().update( - index=system_index, id=index, doc=dict(roles=_roles_to_elastic(roles_dict)) + index=system_index, + id=index, + doc=dict(roles=_roles_to_elastic(roles_dict)), ) -def set_global_role(email: str, role: Role): +def set_global_role(email: str, role: Role | None): """ Set the global role for this user """ set_role(index=GLOBAL_ROLES, email=email, role=role) -def set_guest_role(index: str, guest_role: Optional[Role]): +def set_guest_role(index: str, guest_role: Optional[GuestRole]): """ Set the guest role for this index. Set to None to disallow guest access """ - modify_index(index, guest_role=guest_role, remove_guest_role=(guest_role is None)) + modify_index(index, guest_role=GuestRole.NONE if guest_role is None else guest_role) def modify_index( index: str, name: Optional[str] = None, description: Optional[str] = None, - guest_role: Optional[Role] = None, - remove_guest_role=False, - summary_field=None, + guest_role: Optional[GuestRole] = None, + archived: Optional[str] = None, ): doc = dict( name=name, description=description, - guest_role=guest_role and guest_role.name, - summary_field=summary_field, + guest_role=guest_role.name if guest_role is not None else None, + archived=archived, ) - if summary_field is not None: - f = get_fields(index) - if summary_field not in f: - raise ValueError(f"Summary field {summary_field} does not exist!") - if f[summary_field]["type"] not in ["date", "keyword", "tag"]: - raise ValueError( - f"Summary field {summary_field} should be date, keyword or tag, not {f[summary_field]['type']}!" 
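For illustration, a minimal sketch of the role helpers above in use, assuming a registered index and a running elasticsearch; the index name and email address are invented:

from amcat4.index import GuestRole, Role, set_guest_role, set_role

set_role("example_index", "researcher@example.com", Role.WRITER)   # grant an explicit role
set_guest_role("example_index", GuestRole.METAREADER)              # allow guest (metareader) access
set_role("example_index", "researcher@example.com", None)          # passing None removes the role again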
- ) - doc = {x: v for (x, v) in doc.items() if v} - if remove_guest_role: - doc["guest_role"] = None + + doc = {x: v for (x, v) in doc.items() if v is not None} if doc: es().update(index=get_settings().system_index, id=index, doc=doc) @@ -280,7 +314,23 @@ def remove_global_role(email: str): remove_role(index=GLOBAL_ROLES, email=email) -def get_role(index: str, email: str) -> Optional[Role]: +def user_exists(email: str, index: str = GLOBAL_ROLES) -> bool: + """ + Check if a user exists on server (GLOBAL_ROLES) or in a specific index + """ + try: + doc = es().get( + index=get_settings().system_index, + id=index, + source_includes=["roles", "guest_role"], + ) + except NotFoundError: + raise IndexDoesNotExist(f"Index {index} does not exist or is not registered") + roles_dict = _roles_from_elastic(doc["_source"].get("roles", [])) + return email in roles_dict + + +def get_role(index: str, email: str) -> Role: """ Retrieve the role of this user on this index, or the guest role if user has no role Raises a ValueError if the index does not exist @@ -298,41 +348,51 @@ def get_role(index: str, email: str) -> Optional[Role]: if role := roles_dict.get(email): return role if index == GLOBAL_ROLES: - return None - role = doc["_source"].get("guest_role", None) - if role and role.lower() != "none": - return Role[role] + return Role.NONE + + # are guests allowed? + if get_settings().auth == AuthOptions.authorized_users_only: + # only allow guests if authorized at server level + global_role = get_global_role(email, only_es=True) + if global_role == Role.NONE: + return Role.NONE + return get_guest_role(index) -def get_guest_role(index: str) -> Optional[Role]: + +def get_guest_role(index: str) -> Role: """ Return the guest role for this index, raising a IndexDoesNotExist if the index does not exist :returns: a Role object, or None if global role was NONE """ try: d = es().get( - index=get_settings().system_index, id=index, source_includes="guest_role" + index=get_settings().system_index, + id=index, + source_includes="guest_role", ) except NotFoundError: raise IndexDoesNotExist(index) role = d["_source"].get("guest_role") - if role and role.lower() != "none": + if role and role in Role.__members__: return Role[role] + return Role.NONE -def get_global_role(email: str) -> Optional[Role]: +def get_global_role(email: str, only_es: bool = False) -> Role: """ Retrieve the global role of this user :returns: a Role object, or None if the user has no role """ # The 'admin' user is given to everyone in the no_auth scenario - if email == get_settings().admin_email or email == "admin": - return Role.ADMIN + if only_es is False: + if email == get_settings().admin_email or email == ADMIN_USER: + return Role.ADMIN return get_role(index=GLOBAL_ROLES, email=email) -def list_users(index: str) -> Dict[str, Role]: +def list_users(index: str) -> dict[str, Role]: """ " List all users and their roles on the given index :param index: The index to list roles for. @@ -342,7 +402,7 @@ def list_users(index: str) -> Dict[str, Role]: return _roles_from_elastic(r["_source"].get("roles", [])) -def list_global_users() -> Dict[str, Role]: +def list_global_users() -> dict[str, Role]: """ " List all global users and their roles :returns: an iterable of (user, Role) pairs @@ -355,3 +415,210 @@ def delete_user(email: str) -> None: set_global_role(email, None) for ix in list_known_indices(email): set_role(ix.id, email, None) + + +def create_id(document: dict, field_settings: dict[str, Field]) -> str: + """ + Create the _id for a document. 
+ """ + + identifiers = [k for k, v in field_settings.items() if v.identifier == True] + if len(identifiers) == 0: + raise ValueError("Can only create id if identifiers are specified") + + id_keys = sorted(set(identifiers) & set(document.keys())) + id_fields = {k: document[k] for k in id_keys} + hash_str = json.dumps(id_fields, sort_keys=True, ensure_ascii=True, default=str).encode("ascii") + m = hashlib.sha224() + m.update(hash_str) + return m.hexdigest() + + +def upload_documents( + index: str, + documents: list[dict[str, Any]], + fields: Mapping[str, FieldType | CreateField] | None = None, + op_type: Literal["create", "update"] = "create", + raise_on_error=False, +): + """ + Upload documents to this index + + :param index: The name of the index (without prefix) + :param documents: A sequence of article dictionaries + :param fields: A mapping of fieldname:UpdateField for field types + :param op_type: Whether to 'index' new documents (default) or 'update' existing documents + """ + if fields: + create_fields(index, fields) + + def es_actions(index, documents, op_type): + field_settings = get_fields(index) + has_identifiers = any(field.identifier for field in field_settings.values()) + for document in documents: + doc = dict() + action = {"_op_type": op_type, "_index": index} + + for key in document.keys(): + if key in field_settings: + doc[key] = coerce_type(document[key], field_settings[key].type) + else: + if key != "_id": + raise ValueError(f"Field '{key}' is not yet specified") + + if key == "_id": + if has_identifiers: + identifiers = ", ".join([name for name, field in field_settings.items() if field.identifier]) + raise ValueError(f"This index uses identifier ({identifiers}), so you cannot set the _id directly.") + action["_id"] = document["_id"] + else: + if has_identifiers: + action["_id"] = create_id(document, field_settings) + ## if no id is given, elasticsearch creates a cool unique one + + # https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html + if op_type == "update": + if "_id" not in action: + raise ValueError("Update requires _id") + action["doc"] = doc + action["doc_as_upsert"] = True + else: + action.update(doc) + + yield action + + actions = list(es_actions(index, documents, op_type)) + try: + successes, failures = elasticsearch.helpers.bulk( + es(), + actions, + stats_only=True, + raise_on_error=raise_on_error, + ) + except elasticsearch.helpers.BulkIndexError as e: + logging.error("Error on indexing: " + json.dumps(e.errors, indent=2, default=str)) + if e.errors: + _, error = list(e.errors[0].items())[0] + reason = error.get("error", {}).get("reason", error) + e.args = e.args + (f"First error: {reason}",) + raise + + # Start preprocessors for this index (if any) + processor.get_manager().start_index_preprocessors(index) + + return dict(successes=successes, failures=failures) + + +def get_document(index: str, doc_id: str, **kargs) -> dict: + """ + Get a single document from this index. + + :param index: The name of the index + :param doc_id: The document id (hash) + :return: the source dict of the document + """ + return es().get(index=index, id=doc_id, **kargs)["_source"] + + +def update_document(index: str, doc_id: str, fields: dict): + """ + Update a single document. + + :param index: The name of the index + :param doc_id: The document id (hash) + :param fields: a {field: value} mapping of fields to update + """ + # Mypy doesn't understand that body= has been deprecated already... 
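For illustration, a rough sketch of how upload_documents and create_id above interact when identifier fields are defined; the index name, field names and values are invented:

from amcat4.index import upload_documents
from amcat4.models import CreateField

fields = {
    "url": CreateField(type="keyword", identifier=True),  # identifier fields determine the _id
    "text": "text",                                       # plain FieldType shorthand
}
docs = [{"url": "https://example.com/1", "text": "some text"}]
upload_documents("example_index", docs, fields)
# Because "url" is an identifier, create_id() derives the _id as a sha224 hash over the
# identifier values (json.dumps with sort_keys=True), so re-uploading the same document
# maps to the same _id rather than creating a new one.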
+ es().update(index=index, id=doc_id, doc=fields) # type: ignore + + +def delete_document(index: str, doc_id: str): + """ + Delete a single document + + :param index: The name of the index + :param doc_id: The document id (hash) + """ + es().delete(index=index, id=doc_id) + + +def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None = None): + script_dict = dict(source=script, lang="painless", params=params or {}) + result = es().update_by_query(index=index, script=script_dict, **query, refresh=True) + return dict(updated=result["updated"], total=result["total"]) + + +UPDATE_SCRIPTS = dict( + add=""" + if (ctx._source[params.field] == null) { + ctx._source[params.field] = [params.tag] + } else { + if (ctx._source[params.field].contains(params.tag)) { + ctx.op = 'noop'; + } else { + ctx._source[params.field].add(params.tag) + } + } + """, + remove=""" + if (ctx._source[params.field] != null && ctx._source[params.field].contains(params.tag)) { + ctx._source[params.field].removeAll([params.tag]); + if (ctx._source[params.field].size() == 0) { + ctx._source.remove(params.field); + } + } else { + ctx.op = 'noop'; + } + """, +) + + +def update_tag_by_query(index: str | list[str], action: Literal["add", "remove"], query: dict, field: str, tag: str): + create_or_verify_tag_field(index, field) + script = UPDATE_SCRIPTS[action] + params = dict(field=field, tag=tag) + return update_by_query(index, script, query, params) + + +### WvA Should probably move these to multimedia/actions or something + + +def get_instructions(index: str) -> Iterable[PreprocessingInstruction]: + res = es().get(index=get_settings().system_index, id=index, source="preprocessing") + for i in res["_source"].get("preprocessing", []): + for a in i.get("arguments", []): + if a.get("secret"): + a["value"] = "********" + yield PreprocessingInstruction.model_validate(i) + + +def get_instruction(index: str, field: str) -> Optional[PreprocessingInstruction]: + for i in get_instructions(index): + if i.field == field: + return i + + +def add_instruction(index: str, instruction: PreprocessingInstruction): + if instruction.field in get_fields(index): + raise ValueError(f"Field {instruction.field} already exists in index {index}") + instructions = list(get_instructions(index)) + instructions.append(instruction) + create_fields(index, {instruction.field: "preprocess"}) + body = [i.model_dump() for i in instructions] + es().update(index=get_settings().system_index, id=index, doc=dict(preprocessing=body)) + processor.get_manager().add_preprocessor(index, instruction) + + +def reassign_preprocessing_errors(index: str, field: str): + """Reset status for any documents with error status, and restart preprocessor""" + query = dict(query=dict(term={f"{field}.status": dict(value="error")})) + update_by_query(index, "ctx._source[params.field] = null", query, dict(field=field)) + processor.get_manager().start_preprocessor(index, field) + + +def stop_preprocessor(index: str, field: str): + processor.get_manager().stop_preprocessor(index, field) + + +def start_preprocessor(index: str, field: str): + processor.get_manager().start_preprocessor(index, field) diff --git a/amcat4/models.py b/amcat4/models.py new file mode 100644 index 0000000..6eb8d87 --- /dev/null +++ b/amcat4/models.py @@ -0,0 +1,134 @@ +import pydantic +from pydantic import BaseModel, field_validator, model_validator, validator +from typing import Annotated, Any, Literal +from typing_extensions
import Self + + +FieldType = Literal[ + "text", + "date", + "boolean", + "keyword", + "number", + "integer", + "object", + "vector", + "geo_point", + "image", + "video", + "audio", + "tag", + "json", + "url", + "preprocess", +] +ElasticType = Literal[ + "text", + "annotated_text", + "binary", + "match_only_text", + "date", + "boolean", + "keyword", + "constant_keyword", + "wildcard", + "integer", + "byte", + "short", + "long", + "unsigned_long", + "float", + "half_float", + "double", + "scaled_float", + "object", + "flattened", + "nested", + "dense_vector", + "geo_point", +] + + +class SnippetParams(BaseModel): + """ + Snippet parameters for a specific field. + nomatch_chars is the number of characters to show if there is no query match. This is always + the first [nomatch_chars] of the field. + """ + + nomatch_chars: Annotated[int, pydantic.Field(ge=1)] = 100 + max_matches: Annotated[int, pydantic.Field(ge=0)] = 0 + match_chars: Annotated[int, pydantic.Field(ge=1)] = 50 + + +class FieldMetareaderAccess(BaseModel): + """Metareader access for a specific field.""" + + access: Literal["none", "read", "snippet"] = "none" + max_snippet: SnippetParams | None = None + + +class Field(BaseModel): + """Settings for a field. Some settings, such as metareader, have a strict type because they are used + server side. Others, such as client_settings, are free-form and can be used by the client to store settings.""" + + type: FieldType + elastic_type: ElasticType + identifier: bool = False + metareader: FieldMetareaderAccess = FieldMetareaderAccess() + client_settings: dict[str, Any] = {} + + @model_validator(mode="after") + def validate_type(self) -> Self: + if self.identifier: + # Identifiers have to be immutable. Instead of checking this in every endpoint that performs updates, + # we can disable it for certain types that are known to be mutable. + for forbidden_type in ["tag"]: + if self.type == forbidden_type: + raise ValueError(f"Field type {forbidden_type} cannot be used as an identifier") + return self + + +class CreateField(BaseModel): + """Model for creating a field""" + + type: FieldType + elastic_type: ElasticType | None = None + identifier: bool = False + metareader: FieldMetareaderAccess | None = None + client_settings: dict[str, Any] | None = None + + +class UpdateField(BaseModel): + """Model for updating a field""" + + type: FieldType | None = None + metareader: FieldMetareaderAccess | None = None + client_settings: dict[str, Any] | None = None + + +FilterValue = str | int + + +class FilterSpec(BaseModel): + """Form for filter specification.""" + + values: list[FilterValue] | None = None + gt: FilterValue | None = None + lt: FilterValue | None = None + gte: FilterValue | None = None + lte: FilterValue | None = None + exists: bool | None = None + + +class FieldSpec(BaseModel): + """Form for field specification.""" + + name: str + snippet: SnippetParams | None = None + + +class SortSpec(BaseModel): + """Form for sort specification.""" + + order: Literal["asc", "desc"] = "asc" diff --git a/amcat4/multimedia.py b/amcat4/multimedia.py new file mode 100644 index 0000000..58c016c --- /dev/null +++ b/amcat4/multimedia.py @@ -0,0 +1,123 @@ +""" +Multimedia features for AmCAT + +AmCAT can link to a minio/S3 object store to provide access to multimedia content attached to documents. +The object store needs to be configured in the server settings. 
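To make the new request models in amcat4/models.py concrete, a minimal sketch of how they might be instantiated and handed to query_documents (whose updated signature appears further down in this diff); the index name, field names and values are invented:

from amcat4.models import FieldSpec, FilterSpec, SnippetParams, SortSpec
from amcat4.query import query_documents

fields = [
    FieldSpec(name="title"),
    # show up to 3 matching snippets of ~50 chars, or the first 150 chars if the query does not match
    FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=150, max_matches=3, match_chars=50)),
]
filters = {"party": FilterSpec(values=["Democratic"]), "year": FilterSpec(gte=2000)}
sort = [{"date": SortSpec(order="desc")}]

result = query_documents("example_index", fields=fields, filters=filters, sort=sort, per_page=10)
if result is not None:
    for doc in result.data:
        print(doc["_id"], doc.get("title"))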
+""" + +import datetime +from io import BytesIO +from multiprocessing import Value +from typing import Iterable, Optional +from venv import create +from amcat4.config import get_settings +from minio import Minio, S3Error +from minio.deleteobjects import DeleteObject +from minio.datatypes import PostPolicy, Object +import functools + + +def get_minio() -> Minio: + result = connect_minio() + if result is None: + raise ValueError("Could not connect to minio") + return result + + +@functools.lru_cache() +def connect_minio() -> Optional[Minio]: + try: + return _connect_minio() + except Exception as e: + raise Exception(f"Cannot connect to minio {get_settings().minio_host!r}: {e}") + + +def _connect_minio() -> Optional[Minio]: + settings = get_settings() + if settings.minio_host is None: + return None + if settings.minio_secret_key is None or settings.minio_access_key is None: + raise ValueError("minio_access_key or minio_secret_key not specified") + return Minio( + settings.minio_host, + secure=settings.minio_tls, + access_key=settings.minio_access_key, + secret_key=settings.minio_secret_key, + ) + + +def bucket_name(index: str) -> str: + return index.replace("_", "-") + + +def get_bucket(minio: Minio, index: str, create_if_needed=True): + """ + Get the bucket name for this index. If create_if_needed is True, create the bucket if it doesn't exist. + Returns the bucket name, or "" if it doesn't exist and create_if_needed is False. + """ + bucket = bucket_name(index) + if not minio.bucket_exists(bucket): + if not create_if_needed: + return "" + minio.make_bucket(bucket) + return bucket + + +def list_multimedia_objects( + index: str, prefix: Optional[str] = None, start_after: Optional[str] = None, recursive=True +) -> Iterable[Object]: + minio = get_minio() + bucket = get_bucket(minio, index, create_if_needed=False) + if not bucket: + return + yield from minio.list_objects(bucket_name(index), prefix=prefix, start_after=start_after, recursive=recursive) + + +def stat_multimedia_object(index: str, key: str) -> Object: + minio = get_minio() + bucket = get_bucket(minio, index, create_if_needed=False) + if not bucket: + raise ValueError(f"Bucket for {index} does not exist") + return minio.stat_object(bucket, key) + + +def get_multimedia_object(index: str, key: str) -> bytes: + minio = get_minio() + bucket = get_bucket(minio, index, create_if_needed=False) + if not bucket: + raise ValueError(f"Bucket for {index} does not exist") + res = minio.get_object(bucket, key) + return res.read() + + +def delete_bucket(minio: Minio, index: str): + bucket = get_bucket(minio, index, create_if_needed=False) + if not bucket: + return + to_delete = [DeleteObject(x.object_name) for x in minio.list_objects(bucket, recursive=True) if x.object_name] + errors = list(minio.remove_objects(bucket, to_delete)) + if errors: + raise Exception(f"Error on deleting objects: {errors}") + minio.remove_bucket(bucket) + + +def add_multimedia_object(index: str, key: str, bytes: bytes): + minio = get_minio() + bucket = get_bucket(minio, index) + data = BytesIO(bytes) + minio.put_object(bucket, key, data, len(bytes)) + + +def presigned_post(index: str, key_prefix: str = "", days_valid=1): + minio = get_minio() + bucket = get_bucket(minio, index) + policy = PostPolicy(bucket, expiration=datetime.datetime.now() + datetime.timedelta(days=days_valid)) + policy.add_starts_with_condition("key", key_prefix) + url = f"http{'s' if get_settings().minio_tls else ''}://{get_settings().minio_host}/{bucket}" + return url, 
minio.presigned_post_policy(policy) + + +def presigned_get(index: str, key, days_valid=1): + minio = get_minio() + bucket = get_bucket(minio, index) + return minio.presigned_get_object(bucket, key, expires=datetime.timedelta(days=days_valid)) diff --git a/amcat4/preprocessing/models.py b/amcat4/preprocessing/models.py new file mode 100644 index 0000000..55f25e0 --- /dev/null +++ b/amcat4/preprocessing/models.py @@ -0,0 +1,86 @@ +import copy +from multiprocessing import Value +from typing import Any, Iterable, List, Optional, Tuple + +import httpx +from pydantic import BaseModel + +from amcat4 import multimedia +from amcat4.fields import get_fields +from amcat4.preprocessing.task import get_task + + +class PreprocessingArgument(BaseModel): + name: str + field: Optional[str] = None + value: Optional[str | int | bool | float | List[str] | List[int] | List[float]] = None + secret: Optional[bool] = False + + +class PreprocessingOutput(BaseModel): + name: str + field: str + + +class PreprocessingInstruction(BaseModel): + field: str + task: str + endpoint: str + arguments: List[PreprocessingArgument] + outputs: List[PreprocessingOutput] + + def build_request(self, index, doc) -> httpx.Request: + # TODO: validate that instruction is valid for task! + fields = get_fields(index) + task = get_task(self.task) + if task.request.body == "json": + if not task.request.template: + raise ValueError(f"Task {task.name} has json body but not template") + body = copy.deepcopy(task.request.template) + elif task.request.body == "binary": + body = None + else: + raise NotImplementedError() + headers = {} + for argument in self.arguments: + param = task.get_parameter(argument.name) + if param.use_field == "yes": + if not argument.field: + raise ValueError("Field not given for field param") + value = doc.get(argument.field) + if task.request.body == "binary" and fields[argument.field].type in ["image"]: + value = multimedia.get_multimedia_object(index, value) + else: + value = argument.value + if param.header: + if not param.path: + raise ValueError("Path required for header params") + if ":" in param.path: + path, prefix = param.path.split(":", 1) + prefix = f"{prefix} " + else: + path, prefix = param.path, "" + headers[path] = f"{prefix}{value}" + else: + if task.request.body == "json": + if not param.path: + raise ValueError("Path required for json body params") + param.parsed.update(body, value) + elif task.request.body == "binary": + if param.path: + raise ValueError("Path not allowed for binary body") + if body: + raise ValueError("Multiple values for body") + if type(value) != bytes: + raise ValueError("Binary request requires multimedia object") + body = value + if task.request.body == "json": + return httpx.Request("POST", self.endpoint, json=body, headers=headers) + else: + return httpx.Request("POST", self.endpoint, content=body, headers=headers) + + def parse_output(self, output) -> Iterable[Tuple[str, Any]]: + task = get_task(self.task) + for arg in self.outputs: + o = task.get_output(arg.name) + yield arg.field, o.parsed.find(output)[0].value diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py new file mode 100644 index 0000000..bfcb653 --- /dev/null +++ b/amcat4/preprocessing/processor.py @@ -0,0 +1,181 @@ +import asyncio +from functools import cache +import logging +from typing import Dict, Literal, Tuple +from elasticsearch import NotFoundError +from httpx import AsyncClient, HTTPStatusError +from amcat4.elastic import es +import amcat4.index +from 
amcat4.preprocessing.models import PreprocessingInstruction + +logger = logging.getLogger("amcat4.preprocessing") + +PreprocessorStatus = Literal["Active", "Paused", "Unknown", "Error", "Stopped", "Done"] + + +class RateLimit(Exception): + pass + + +PAUSE_ON_RATE_LIMIT_SECONDS = 10 + + +class PreprocessorManager: + SINGLETON = None + + def __init__(self): + self.preprocessors: Dict[Tuple[str, str], PreprocessingInstruction] = {} + self.running_tasks: Dict[Tuple[str, str], asyncio.Task] = {} + self.preprocessor_status: Dict[Tuple[str, str], PreprocessorStatus] = {} + + def set_status(self, index: str, field: str, status: PreprocessorStatus): + self.preprocessor_status[index, field] = status + + def add_preprocessor(self, index: str, instruction: PreprocessingInstruction): + """Start a new preprocessor task and add it to the manager, returning the Task object""" + self.preprocessors[index, instruction.field] = instruction + self.start_preprocessor(index, instruction.field) + + def start_preprocessor(self, index: str, field: str): + if existing_task := self.running_tasks.get((index, field)): + if not existing_task.done: + return existing_task + instruction = self.preprocessors[index, field] + task = asyncio.create_task(run_processor_loop(index, instruction)) + self.running_tasks[index, instruction.field] = task + + def start_index_preprocessors(self, index: str): + for ix, field in self.preprocessors: + if ix == index: + self.start_preprocessor(index, field) + + def stop_preprocessor(self, index: str, field: str): + """Stop a preprocessor task""" + try: + if task := self.running_tasks.get((index, field)): + task.cancel() + except: + logger.exception(f"Error on cancelling preprocessor {index}:{field}") + + def remove_preprocessor(self, index: str, field: str): + """Stop this preprocessor remove them from the manager""" + self.stop_preprocessor(index, field) + del self.preprocessor_status[index, field] + del self.preprocessors[index, field] + + def remove_index_preprocessors(self, index: str): + """Stop all preprocessors on this index and remove them from the manager""" + for ix, field in list(self.preprocessors.keys()): + if index == ix: + self.remove_preprocessor(ix, field) + + def get_status(self, index: str, field: str) -> PreprocessorStatus: + status = self.preprocessor_status.get((index, field), "Unknown") + task = self.running_tasks.get((index, field)) + if (not task) or task.done() and status == "Active": + logger.warning(f"Preprocessor {index}.{field} is {status}, but has no running task: {task}") + return "Unknown" + return status + + +@cache +def get_manager(): + return PreprocessorManager() + + +def start_processors(): + logger.info("Starting preprocessing loops (if needed)") + manager = get_manager() + for index in amcat4.index.list_known_indices(): + try: + instructions = list(amcat4.index.get_instructions(index.id)) + except NotFoundError: + logger.warning(f"Index {index.id} does not exist!") + continue + for instruction in instructions: + manager.add_preprocessor(index.id, instruction) + + +async def run_processor_loop(index, instruction: PreprocessingInstruction): + """ + Main preprocessor loop. 
+ Calls process_documents to process a batch of documents, until 'done' + """ + logger.info(f"Preprocessing START for {index}.{instruction.field}") + get_manager().set_status(index, instruction.field, "Active") + done = False + while not done: + try: + done = await process_documents(index, instruction) + except asyncio.CancelledError: + logger.info(f"Preprocessing CANCEL for {index}.{instruction.field}") + get_manager().set_status(index, instruction.field, "Stopped") + raise + except RateLimit: + logger.info(f"Preprocessing RATELIMIT for {index}.{instruction.field}") + get_manager().set_status(index, instruction.field, "Paused") + await asyncio.sleep(PAUSE_ON_RATE_LIMIT_SECONDS) + get_manager().set_status(index, instruction.field, "Active") + except Exception: + logger.exception(f"Preprocessing ERROR for {index}.{instruction.field}") + get_manager().set_status(index, instruction.field, "Error") + return + get_manager().set_status(index, instruction.field, "Done") + logger.info(f"Preprocessing DONE for {index}.{instruction.field}") + + +async def process_documents(index: str, instruction: PreprocessingInstruction, size=100): + """ + Process a batch of currently to-do documents in the index for this instruction. + Return value indicates job completion: + It returns True when it runs out of documents to do, or False if there might be more documents. + """ + # Refresh index before getting new documents to make sure status updates are reflected + amcat4.index.refresh_index(index) + docs = list(get_todo(index, instruction, size=size)) + if not docs: + return True + logger.debug(f"Preprocessing for {index}.{instruction.field}: retrieved {len(docs)} docs to process") + for doc in docs: + await process_doc(index, instruction, doc) + return False + + +def get_todo(index: str, instruction: PreprocessingInstruction, size=100): + fields = [arg.field for arg in instruction.arguments if arg.field] + q = dict(bool=dict(must_not=dict(exists=dict(field=instruction.field)))) + for doc in es().search(index=index, size=size, source_includes=fields, query=q)["hits"]["hits"]: + yield {"_id": doc["_id"], **doc["_source"]} + + +def get_counts(index: str, field: str): + agg = dict(status=dict(terms=dict(field=f"{field}.status"))) + + res = es().search(index=index, size=0, aggs=agg) + result = dict(total=res["hits"]["total"]["value"]) + for bucket in res["aggregations"]["status"]["buckets"]: + result[bucket["key"]] = bucket["doc_count"] + return result + + +async def process_doc(index: str, instruction: PreprocessingInstruction, doc: dict): + # TODO catch errors and add to status field, rather than raising + try: + req = instruction.build_request(index, doc) + except Exception as e: + logger.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}") + amcat4.index.update_document(index, doc["_id"], {instruction.field: dict(status="error", error=str(e))}) + return + try: + response = await AsyncClient().send(req) + response.raise_for_status() + except HTTPStatusError as e: + if e.response.status_code == 503: + raise RateLimit(e) + logger.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}: {e.response.text}") + body = dict(status="error", status_code=e.response.status_code, response=e.response.text) + amcat4.index.update_document(index, doc["_id"], {instruction.field: body}) + return + result = dict(instruction.parse_output(response.json())) + result[instruction.field] = dict(status="done") + amcat4.index.update_document(index, doc["_id"], result) diff
--git a/amcat4/preprocessing/task.py b/amcat4/preprocessing/task.py new file mode 100644 index 0000000..1f00712 --- /dev/null +++ b/amcat4/preprocessing/task.py @@ -0,0 +1,134 @@ +import functools +from multiprocessing import Value +from re import I +from typing import Any, Dict, List, Literal, Optional +from pydantic import BaseModel +import jsonpath_ng + +from amcat4.models import FieldType + +""" +https://huggingface.co/docs/api-inference/detailed_parameters + +""" + + +class PreprocessingRequest(BaseModel): + body: Literal["json", "binary"] + template: Optional[dict] = None + + +class PreprocessingSetting(BaseModel): + name: str + type: str = "string" + path: Optional[str] = None + + @functools.cached_property + def parsed(self) -> jsonpath_ng.JSONPath: + return jsonpath_ng.parse(self.path) + + +class PreprocessingOutput(PreprocessingSetting): + recommended_type: FieldType + + +class PreprocessingParameter(PreprocessingSetting): + use_field: Literal["yes", "no"] = "no" + default: Optional[bool | str | int | float] = None + placeholder: Optional[str] = None + header: Optional[bool] = None + secret: Optional[bool] = False + + +class PreprocessingEndpoint(BaseModel): + placeholder: str + domain: List[str] + + +class PreprocessingTask(BaseModel): + """Form for query metadata.""" + + name: str + endpoint: PreprocessingEndpoint + parameters: List[PreprocessingParameter] + outputs: List[PreprocessingOutput] + request: PreprocessingRequest + + def get_parameter(self, name) -> PreprocessingParameter: + # TODO should probably cache this + for param in self.parameters: + if param.name == name: + return param + raise ValueError(f"Parameter {name} not defined on task {self.name}") + + def get_output(self, name) -> PreprocessingOutput: + # TODO should probably cache this + for output in self.outputs: + if output.name == name: + return output + raise ValueError(f"Parameter {name} not defined on task {self.name}") + + +TASKS: List[PreprocessingTask] = [ + PreprocessingTask( + # https://huggingface.co/docs/api-inference/detailed_parameters#zero-shot-classification-task + name="HuggingFace Zero-Shot", + endpoint=PreprocessingEndpoint( + placeholder="https://api-inference.huggingface.co/models/facebook/bart-large-mnli", + domain=["huggingface.co", "huggingfacecloud.com"], + ), + parameters=[ + PreprocessingParameter(name="input", type="string", use_field="yes", path="$.inputs"), + PreprocessingParameter( + name="candidate_labels", + type="string[]", + use_field="no", + placeholder="politics, sports", + path="$.parameters.candidate_labels", + ), + PreprocessingParameter( + name="Huggingface Token", + type="string", + use_field="no", + header=True, + path="Authorization:Bearer", + secret=True, + ), + ], + outputs=[PreprocessingOutput(name="label", recommended_type="keyword", path="$.labels[0]")], + request=PreprocessingRequest(body="json", template={"inputs": "", "parameters": {"candidate_labels": ""}}), + ), + PreprocessingTask( + # https://huggingface.co/docs/api-inference/detailed_parameters#zero-shot-classification-task + name="HuggingFace Image Classification", + endpoint=PreprocessingEndpoint( + placeholder="https://api-inference.huggingface.co/models/google/vit-base-patch16-224", + domain=["huggingface.co", "huggingfacecloud.com"], + ), + parameters=[ + PreprocessingParameter(name="input", type="image", use_field="yes"), + PreprocessingParameter( + name="Huggingface Token", + type="string", + use_field="no", + header=True, + path="Authorization:Bearer", + secret=True, + ), + ], + 
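As a hedged illustration of how build_request (defined in amcat4/preprocessing/models.py above) fills the zero-shot task template for a single document; the field mapping, labels and token below are invented:

# Suppose an instruction maps the document field "text" to the "input" parameter,
# sets candidate_labels to "politics, sports", and supplies a (fake) HuggingFace token.
# build_request fills the JSON template via the parameters' jsonpath paths:
body = {"inputs": "this is a text", "parameters": {"candidate_labels": "politics, sports"}}
# ...and the token parameter, whose path is "Authorization:Bearer", becomes a request header:
headers = {"Authorization": "Bearer hf_fake_token"}
# On success, parse_output reads $.labels[0] from the response and stores it in the
# output field configured on the instruction.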
outputs=[PreprocessingOutput(name="label", recommended_type="keyword", path="$[0].label")], + request=PreprocessingRequest(body="binary"), + ), +] + + +@functools.cache +def get_task(name): + for task in TASKS: + if task.name == name: + return task + raise ValueError(f"Task {task} not defined") + + +def get_tasks(): + return TASKS diff --git a/amcat4/query.py b/amcat4/query.py index 05ce3cf..4f63df9 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -1,27 +1,43 @@ """ All things query """ + from math import ceil -from re import finditer -from re import sub -from typing import Mapping, Iterable, Optional, Union, Sequence, Any, Dict, List, Tuple, Literal + +from typing import ( + Union, + Sequence, + Any, + Dict, + Tuple, + Literal, +) + +from urllib3 import Retry + +from amcat4.models import FieldSpec, FilterSpec, SortSpec from .date_mappings import mappings -from .elastic import es, update_tag_by_query +from .elastic import es +from amcat4.index import update_tag_by_query -def build_body(queries: Iterable[str] = None, filters: Mapping = None, highlight: Union[bool, dict] = False, - ids: Iterable[str] = None): - def parse_filter(field, filter) -> Tuple[Mapping, Mapping]: - filter = filter.copy() +def build_body( + queries: dict[str, str] | None = None, + filters: dict[str, FilterSpec] | None = None, + highlight: dict | None = None, + ids: list[str] | None = None, +): + def parse_filter(field: str, filterSpec: FilterSpec) -> Tuple[dict, dict]: + filter = filterSpec.model_dump(exclude_none=True) extra_runtime_mappings = {} field_filters = [] - for value in filter.pop('values', []): + for value in filter.pop("values", []): field_filters.append({"term": {field: value}}) - if 'value' in filter: - field_filters.append({"term": {field: filter.pop('value')}}) - if 'exists' in filter: - if filter.pop('exists'): + if "value" in filter: + field_filters.append({"term": {field: filter.pop("value")}}) + if "exists" in filter: + if filter.pop("exists"): field_filters.append({"exists": {"field": field}}) else: field_filters.append({"bool": {"must_not": {"exists": {"field": field}}}}) @@ -31,25 +47,28 @@ def parse_filter(field, filter) -> Tuple[Mapping, Mapping]: extra_runtime_mappings.update(mapping.mapping(field)) field_filters.append({"term": {mapping.fieldname(field): value}}) rangefilter = {} - for rangevar in ['gt', 'gte', 'lt', 'lte']: + for rangevar in ["gt", "gte", "lt", "lte"]: if rangevar in filter: rangefilter[rangevar] = filter.pop(rangevar) if rangefilter: field_filters.append({"range": {field: rangefilter}}) if filter: raise ValueError(f"Unknown filter type(s): {filter}") - return extra_runtime_mappings, {'bool': {'should': field_filters}} + return extra_runtime_mappings, {"bool": {"should": field_filters}} def parse_query(q: str) -> dict: - return {"query_string": {"query": q}} + return {"query_string": {"query": q}} - def parse_queries(qs: Sequence[str]) -> dict: + def parse_queries(queries: dict[str, str]) -> dict: + qs = queries.values() if len(qs) == 1: return parse_query(list(qs)[0]) else: return {"bool": {"should": [parse_query(q) for q in qs]}} - if not (queries or filters or ids): - return {'query': {'match_all': {}}} + + if not (queries or filters or ids or highlight): + return {"query": {"match_all": {}}} + fs, runtime_mappings = [], {} if filters: for field, filter in filters.items(): @@ -57,28 +76,30 @@ def parse_queries(qs: Sequence[str]) -> dict: fs.append(filter_term) if extra_runtime_mappings: runtime_mappings.update(extra_runtime_mappings) - if queries: - if 
isinstance(queries, dict): - queries = queries.values() - fs.append(parse_queries(list(queries))) + if queries is not None: + fs.append(parse_queries(queries)) if ids: fs.append({"ids": {"values": list(ids)}}) body: Dict[str, Any] = {"query": {"bool": {"filter": fs}}} if runtime_mappings: - body['runtime_mappings'] = runtime_mappings - if highlight is True: - highlight = {"number_of_fragments": 0} - elif highlight: - highlight = {**{"number_of_fragments": 0, "fragment_size": 40, "type": "plain"}, **highlight} - if highlight: - body['highlight'] = {"type": 'unified', "require_field_match": True, - "fields": {"*": highlight}} + body["runtime_mappings"] = runtime_mappings + + if highlight is not None: + body["highlight"] = highlight + return body class QueryResult: - def __init__(self, data: List[dict], - n: int = None, per_page: int = None, page: int = None, page_count: int = None, scroll_id: str = None): + def __init__( + self, + data: list[dict], + n: int | None = None, + per_page: int | None = None, + page: int | None = None, + page_count: int | None = None, + scroll_id: str | None = None, + ): if n and (page_count is None) and (per_page is not None): page_count = ceil(n / per_page) self.data = data @@ -88,33 +109,33 @@ def __init__(self, data: List[dict], self.per_page = per_page self.scroll_id = scroll_id - def as_dict(self): - meta = {"total_count": self.total_count, - "per_page": self.per_page, - "page_count": self.page_count, - } + def as_dict(self) -> dict: + meta: dict[str, int | str | None] = { + "total_count": self.total_count, + "per_page": self.per_page, + "page_count": self.page_count, + } if self.scroll_id: - meta['scroll_id'] = self.scroll_id + meta["scroll_id"] = self.scroll_id else: - meta['page'] = self.page + meta["page"] = self.page return dict(meta=meta, results=self.data) -def _normalize_queries(queries: Optional[Union[Dict[str, str], Iterable[str]]]) -> Mapping[str, str]: - if queries is None: - return {} - if isinstance(queries, dict): - return queries - return {q: q for q in queries} - - -def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str, str], Iterable[str]] = None, *, - page: int = 0, per_page: int = 10, - scroll=None, scroll_id: str = None, fields: Iterable[str] = None, - filters: Mapping[str, Mapping] = None, - highlight: Union[bool, dict] = False, annotations=False, - sort: List[Union[str, Mapping]] = None, - **kwargs) -> Optional[QueryResult]: +def query_documents( + index: Union[str, list[str]], + fields: list[FieldSpec] | None = None, + queries: dict[str, str] | None = None, + filters: dict[str, FilterSpec] | None = None, + sort: list[dict[str, SortSpec]] | None = None, + *, + page: int = 0, + per_page: int = 10, + scroll=None, + scroll_id: str | None = None, + highlight: bool = False, + **kwargs, +) -> QueryResult | None: """ Conduct a query_string query, returning the found documents. @@ -123,120 +144,132 @@ def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str If the scroll parameter is given, the result will contain a scroll_id which can be used to get the next batch. In case there are no more documents to scroll, it will return None :param index: The name of the index or indexes - :param queries: a list of queries OR a dict {label1: query1, ...} + :param fields: List of fields using the FieldSpec syntax. If not specified, only return _id. + !Any logic for determining whether a user can see the field should be done in the API layer. 
+ :param queries: if not None, a dict with labels and queries {label1: query1, ...} + :param filters: if not None, a dict where the key is the field and the value is a FilterSpec + :param page: The number of the page to request (starting from zero) :param per_page: The number of hits per page :param scroll: if not None, will create a scroll request rather than a paginated request. Parmeter should specify the time the context should be kept alive, or True to get the default of 2m. :param scroll_id: if not None, should be a previously returned context_id to retrieve a new page of results - :param fields: if not None, specify a list of fields to retrieve for each hit - :param filters: if not None, a dict of filters with either value, values, or gte/gt/lte/lt ranges: - {field: {'values': [value1,value2], - 'value': value, - 'gte/gt/lte/lt': value, - ...}} - :param highlight: if True, add highlight tags () to all results. - If a dict, it can be used to control highlighting, e.g. to get multiple snippets - (https://www.elastic.co/guide/en/elasticsearch/reference/7.17/highlighting.html) - :param annotations: if True, get query matches as annotations. + :param highlight: if True, add tags to query matches in fields :param sort: Sort order of results, can be either a single field or a list of fields. In the list, each field is a string or a dict with options, e.g. ["id", {"date": {"order": "desc"}}] (https://www.elastic.co/guide/en/elasticsearch/reference/current/sort-search-results.html) :param kwargs: Additional elements passed to Elasticsearch.search() :return: a QueryResult, or None if there is not scroll result anymore """ + if fields is not None and not isinstance(fields, list): + raise ValueError("fields should be a list") + if scroll or scroll_id: # set scroll to default also if scroll_id is given but no scroll time is known - kwargs['scroll'] = '2m' if (not scroll or scroll is True) else scroll - queries = _normalize_queries(queries) + kwargs["scroll"] = "2m" if (not scroll or scroll is True) else scroll + if sort is not None: - kwargs["sort"] = sort + kwargs["sort"] = [] + for s in sort: + for k, v in s.items(): + kwargs["sort"].append({k: dict(v)}) + if scroll_id: result = es().scroll(scroll_id=scroll_id, **kwargs) - if not result['hits']['hits']: + if not result["hits"]["hits"]: return None else: - body = build_body(queries.values(), filters, highlight) + h = query_highlight(fields, highlight) if fields is not None else None + body = build_body(queries, filters, h) + + fieldnames = [field.name for field in fields] if fields is not None else ["_id"] + kwargs["_source"] = fieldnames - if fields: - fields = fields if isinstance(fields, list) else list(fields) - kwargs['_source'] = fields if not scroll: - kwargs['from_'] = page * per_page + kwargs["from_"] = page * per_page result = es().search(index=index, size=per_page, **body, **kwargs) data = [] - for hit in result['hits']['hits']: - hitdict = dict(_id=hit['_id'], **hit['_source']) - if annotations: - hitdict['_annotations'] = list(query_annotations(index, hit['_id'], queries)) - if 'highlight' in hit: - for key in hit['highlight'].keys(): - if hit['highlight'][key]: - hitdict[key] = " ... ".join(hit['highlight'][key]) + for hit in result["hits"]["hits"]: + hitdict = dict(_id=hit["_id"], **hit["_source"]) + hitdict = overwrite_highlight_results(hit, hitdict) + if "highlight" in hit: + for key in hit["highlight"].keys(): + if hit["highlight"][key]: + hitdict[key] = " ... 
".join(hit["highlight"][key]) data.append(hitdict) if scroll_id: - return QueryResult(data, n=result['hits']['total']['value'], scroll_id=result['_scroll_id']) + return QueryResult(data, n=result["hits"]["total"]["value"], scroll_id=result["_scroll_id"]) elif scroll: - return QueryResult(data, n=result['hits']['total']['value'], per_page=per_page, scroll_id=result['_scroll_id']) + return QueryResult( + data, + n=result["hits"]["total"]["value"], + per_page=per_page, + scroll_id=result["_scroll_id"], + ) else: - return QueryResult(data, n=result['hits']['total']['value'], per_page=per_page, page=page) + return QueryResult(data, n=result["hits"]["total"]["value"], per_page=per_page, page=page) -def query_annotations(index: str, id: str, queries: Mapping[str, str]) -> Iterable[Dict]: +def query_highlight(fields: list[FieldSpec], highlight_queries: bool = False) -> dict[str, Any]: """ - get query matches in annotation format. Currently does so per hit per query. - Per hit could be optimized, but per query seems necessary: - https://stackoverflow.com/questions/44621694/elasticsearch-highlight-with-multiple-queries-not-work-as-expected + The elastic "highlight" parameters works for both highlighting text fields and adding snippets. + This function will return the highlight parameter to be added to the query body. """ - if not queries: - return - for label, query in queries.items(): - body = build_body([query], {'_id': {'value': id}}, True) - - result = es().search(index=index, body=body) - hit = result['hits']['hits'] - if len(hit) == 0: - continue - for field, highlights in hit[0]['highlight'].items(): - text = hit[0]["_source"][field] - if isinstance(text, list): - continue - for span in extract_highlight_span(text, highlights[0]): - span['variable'] = 'query' - span['value'] = label - span['field'] = field - yield span - - -def extract_highlight_span(text: str, highlight: str): - """ - It doesn't seem possible to get the offsets of highlights: - https://github.com/elastic/elasticsearch/issues/5736 + highlight: dict[str, Any] = { + "pre_tags": [""] if highlight_queries is True else [""], + "post_tags": [""] if highlight_queries is True else [""], + "require_field_match": True, + "fields": {}, + } + + for field in fields: + if field.snippet is None: + if highlight_queries is True: + # This will overwrite the field with the highlighted version, so + # only needed if highlight is True + highlight["fields"][field.name] = {"number_of_fragments": 0} + else: + # the elastic highlight feature is also used to get snippets. note that + # above in the + highlight["fields"][field.name] = { + "no_match_size": field.snippet.nomatch_chars, + "number_of_fragments": field.snippet.max_matches, + "fragment_size": field.snippet.match_chars or 1, # 0 would return the whole field + } + if field.snippet.max_matches == 0: + # If max_matches is zero, we drop the query for highlighting so that + # the nomatch_chars are returned + highlight["fields"][field.name]["highlight_query"] = {"match_all": {}} - We can get the offsets from the tags, but not yet sure how stable this is. - text is the text in the _source field. highlight should be elastics highlight if nr of fragments = 0 (i.e. full text) + return highlight + + +def overwrite_highlight_results(hit: dict, hitdict: dict): + """ + highlights are a separate field in the hits. If highlight is True, we want to overwrite + the original field with the highlighted version. If there are snippets, we want to add them """ - # elastic highlighting internally trims... 
- # this hack gets the offset of the trimmed text, but it's not an ideal solution - trimmed_offset = len(text) - len(text.lstrip()) - - side_by_side = ' ' - highlight = sub(side_by_side, ' ', highlight) - regex = '.+?' - tagsize = 9 # - for i, m in enumerate(finditer(regex, highlight)): - offset = trimmed_offset + m.start(0) - tagsize*i - length = len(m.group(0)) - tagsize - yield dict(offset=offset, length=length) - - -def update_tag_query(index: Union[str, Sequence[str]], action: Literal["add", "remove"], - field: str, tag: str, - queries: Union[Mapping[str, str], Iterable[str]] = None, - filters: Mapping[str, Mapping] = None, - ids: Sequence[str] = None): + if not hit.get("highlight"): + return hitdict + for key in hit["highlight"].keys(): + if hit["highlight"][key]: + hitdict[key] = " ... ".join(hit["highlight"][key]) + return hitdict + + +def update_tag_query( + index: str | list[str], + action: Literal["add", "remove"], + field: str, + tag: str, + queries: dict[str, str] | None = None, + filters: dict[str, FilterSpec] | None = None, + ids: list[str] | None = None, +): """Add or remove tags using a query""" - body = build_body(queries and queries.values(), filters, ids=ids) - update_tag_by_query(index, action, body, field, tag) + body = build_body(queries, filters, ids=ids) + + update_result = update_tag_by_query(index, action, body, field, tag) + return update_result diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..75ad82a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,42 @@ +anyio==3.7.1 +asgiref==3.6.0 +attrs==19.3.0 +Authlib==1.2.1 +certifi==2023.7.22 +cffi==1.15.1 +charset-normalizer==3.2.0 +class-doc==0.2.0b0 +click==8.1.6 +cryptography==41.0.2 +dnspython==2.4.1 +elastic-transport==8.4.0 +elasticsearch==8.8.2 +email-validator==1.3.1 +exceptiongroup==1.1.2 +fastapi==0.78.0 +h11==0.14.0 +httptools==0.6.0 +idna==3.4 +itsdangerous==2.1.2 +Jinja2==3.1.2 +MarkupSafe==2.1.3 +more-itertools==7.2.0 +orjson==3.9.2 +pycparser==2.21 +pydantic==1.9.2 +pydantic-settings==0.2.5 +python-dotenv==1.0.0 +python-multipart==0.0.5 +PyYAML==5.4.1 +requests==2.31.0 +six==1.16.0 +sniffio==1.3.0 +starlette==0.19.1 +tomlkit==0.5.11 +typing-extensions==3.10.0.2 +ujson==5.8.0 +urllib3==1.26.16 +uvicorn==0.17.6 +uvloop==0.17.0 +watchgod==0.8.2 +websockets==11.0.3 diff --git a/setup.py b/setup.py index 4c92880..26c13fe 100644 --- a/setup.py +++ b/setup.py @@ -30,14 +30,21 @@ "uvicorn", "requests", "class_doc", + "mypy", + "minio", + "jsonpath_ng", ], extras_require={ "dev": [ "pytest", + "pytest-httpx", "mypy", "flake8", "responses", "pre-commit", + "types-requests", + "pytest-asyncio", + "pytest-minio-mock @ git+ssh://git@github.com/vanatteveldt/pytest-minio-mock.git", ] }, entry_points={"console_scripts": ["amcat4 = amcat4.__main__:main"]}, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py index bd651af..ea97252 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ -from typing import Iterable - +import logging +from typing import Any, AsyncGenerator, AsyncIterable import pytest +import pytest_asyncio import responses from fastapi.testclient import TestClient +from httpx import ASGITransport, AsyncClient -from amcat4 import elastic, api # noqa: E402 +from amcat4 import api, multimedia # noqa: E402 from amcat4.config import get_settings, AuthOptions from amcat4.elastic import es from amcat4.index import ( @@ -15,7 +17,9 @@ delete_user, remove_global_role, 
set_global_role, + upload_documents, ) +from amcat4.models import CreateField, FieldType from tests.middlecat_keypair import PUBLIC_KEY UNITS = [ @@ -33,9 +37,7 @@ def mock_middlecat(): get_settings().middlecat_url = "http://localhost:5000" get_settings().host = "http://localhost:3000" with responses.RequestsMock(assert_all_requests_are_fired=False) as resp: - resp.get( - "http://localhost:5000/api/configuration", json={"public_key": PUBLIC_KEY} - ) + resp.get("http://localhost:5000/api/configuration", json={"public_key": PUBLIC_KEY}) yield None @@ -137,39 +139,19 @@ def guest_index(): delete_index(index, ignore_missing=True) -def upload(index: str, docs: Iterable[dict], **kwargs): +def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, FieldType | CreateField] | None = None): """ Upload these docs to the index, giving them an incremental id, and flush """ - ids = [] - for i, doc in enumerate(docs): - id = str(i) - ids.append(id) - defaults = {"title": "title", "date": "2018-01-01", "text": "text", "_id": id} - for k, v in defaults.items(): - if k not in doc: - doc[k] = v - elastic.upload_documents(index, docs, **kwargs) + res = upload_documents(index, docs, fields) refresh_index(index) - return ids TEST_DOCUMENTS = [ + {"_id": 0, "cat": "a", "subcat": "x", "i": 1, "date": "2018-01-01", "text": "this is a text", "title": "title"}, + {"_id": 1, "cat": "a", "subcat": "x", "i": 2, "date": "2018-02-01", "text": "a test text", "title": "title"}, { - "cat": "a", - "subcat": "x", - "i": 1, - "date": "2018-01-01", - "text": "this is a text", - }, - { - "cat": "a", - "subcat": "x", - "i": 2, - "date": "2018-02-01", - "text": "a test text", - }, - { + "_id": 2, "cat": "a", "subcat": "y", "i": 11, @@ -178,6 +160,7 @@ def upload(index: str, docs: Iterable[dict], **kwargs): "title": "bla", }, { + "_id": 3, "cat": "b", "subcat": "y", "i": 31, @@ -189,10 +172,18 @@ def upload(index: str, docs: Iterable[dict], **kwargs): def populate_index(index): + upload( index, TEST_DOCUMENTS, - fields={"cat": "keyword", "subcat": "keyword", "i": "long"}, + fields={ + "text": "text", + "title": "text", + "date": "date", + "cat": "keyword", + "subcat": "keyword", + "i": "integer", + }, ) return TEST_DOCUMENTS @@ -214,10 +205,8 @@ def index_many(): create_index(index, guest_role=Role.READER) upload( index, - [ - dict(id=i, pagenr=abs(10 - i), text=text) - for (i, text) in enumerate(["odd", "even"] * 10) - ], + [dict(id=i, pagenr=abs(10 - i), text=text) for (i, text) in enumerate(["odd", "even"] * 10)], + fields={"id": "integer", "pagenr": "integer", "text": "text"}, ) yield index delete_index(index, ignore_missing=True) @@ -226,3 +215,21 @@ def index_many(): @pytest.fixture() def app(): return api.app + + +@pytest.fixture() +def minio(minio_mock): + from minio.deleteobjects import DeleteObject + + minio = multimedia.get_minio() + for bucket in minio.list_buckets(): + for x in minio.list_objects(bucket.name, recursive=True): + minio.remove_object(x.bucket_name, x.object_name or "") + minio.remove_bucket(bucket.name) + + +@pytest_asyncio.fixture +async def aclient(app) -> AsyncIterable[AsyncClient]: + host = get_settings().host + async with AsyncClient(transport=ASGITransport(app=app), base_url=host) as c: + yield c diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 8d94590..7ec3ee8 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -2,17 +2,21 @@ from datetime import datetime, date from amcat4.aggregate import query_aggregate, Axis, Aggregation +from amcat4.api.query 
import _standardize_queries +from amcat4.models import CreateField, Field from tests.conftest import upload from tests.tools import dictset -def do_query(index: str, *axes, **kargs): +def do_query(index: str, *args, **kwargs): def _key(x): if len(x) == 1: return x[0] return x - axes = [Axis(x) if isinstance(x, str) else x for x in axes] - result = query_aggregate(index, axes, **kargs) + + axes = [Axis(x) if isinstance(x, str) else x for x in args] + + result = query_aggregate(index, axes, **kwargs) return {_key(vals[:-1]): vals[-1] for vals in result.data} @@ -27,80 +31,100 @@ def _y(y): def test_aggregate(index_docs): q = functools.partial(do_query, index_docs) assert q(Axis("cat")) == {"a": 3, "b": 1} - assert q(Axis(field="date")) == {_d('2018-01-01'): 2, _d('2018-02-01'): 1, _d('2020-01-01'): 1} + assert q(Axis(field="date")) == {_d("2018-01-01"): 2, _d("2018-02-01"): 1, _d("2020-01-01"): 1} def test_aggregate_querystring(index_docs): q = functools.partial(do_query, index_docs) - assert q("cat", queries=['toto']) == {"a": 1, "b": 1} - assert q("cat", queries=['test*']) == {"a": 2, "b": 1} - assert q("cat", queries=['"a text"', 'another']) == {"a": 2} + assert q("cat", queries=_standardize_queries(["toto"])) == {"a": 1, "b": 1} + assert q("cat", queries=_standardize_queries(["test*"])) == {"a": 2, "b": 1} + assert q("cat", queries=_standardize_queries(['"a text"', "another"])) == {"a": 2} def test_interval(index_docs): q = functools.partial(do_query, index_docs) assert q(Axis(field="date", interval="year")) == {_y(2018): 3, _y(2020): 1} - assert q(Axis(field="i", interval="10")) == {0.: 2, 10.: 1, 30.: 1} + assert q(Axis(field="i", interval="10")) == {0.0: 2, 10.0: 1, 30.0: 1} def test_second_axis(index_docs): q = functools.partial(do_query, index_docs) - assert q("cat", 'subcat') == {("a", "x"): 2, ("a", "y"): 1, ("b", "y"): 1} - assert (q(Axis(field="date", interval="year"), 'cat') - == {(_y(2018), "a"): 2, (_y(2020), "a"): 1, (_y(2018), "b"): 1}) - assert (q('cat', Axis(field="date", interval="year")) - == {("a", _y(2018)): 2, ("a", _y(2020)): 1, ("b", _y(2018)): 1}) + assert q("cat", "subcat") == {("a", "x"): 2, ("a", "y"): 1, ("b", "y"): 1} + assert q(Axis(field="date", interval="year"), "cat") == {(_y(2018), "a"): 2, (_y(2020), "a"): 1, (_y(2018), "b"): 1} + assert q("cat", Axis(field="date", interval="year")) == {("a", _y(2018)): 2, ("a", _y(2020)): 1, ("b", _y(2018)): 1} def test_count(index_docs): """Does aggregation without axes work""" assert do_query(index_docs) == {(): 4} - assert do_query(index_docs, queries=["text"]) == {(): 2} + assert do_query(index_docs, queries={"text": "text"}) == {(): 2} def test_byquery(index_docs): """Get number of documents per query""" - assert do_query(index_docs, Axis("_query"), queries=["text", "test*"]) == {"text": 2, "test*": 3} - assert (do_query(index_docs, Axis("_query"), Axis("subcat"), queries=["text", "test*"]) == - {("text", "x"): 2, ("test*", "x"): 1, ("test*", "y"): 2}) - assert (do_query(index_docs, Axis("subcat"), Axis("_query"), queries=["text", "test*"]) == - {("x", "text"): 2, ("x", "test*"): 1, ("y", "test*"): 2}) + assert do_query(index_docs, Axis("_query"), queries={"text": "text", "test*": "test*"}) == {"text": 2, "test*": 3} + assert do_query(index_docs, Axis("_query"), Axis("subcat"), queries={"text": "text", "test*": "test*"}) == { + ("text", "x"): 2, + ("test*", "x"): 1, + ("test*", "y"): 2, + } + assert do_query(index_docs, Axis("subcat"), Axis("_query"), queries={"text": "text", "test*": "test*"}) == { + ("x", 
"text"): 2, + ("x", "test*"): 1, + ("y", "test*"): 2, + } def test_metric(index_docs: str): """Do metric aggregations (e.g. avg(x)) work?""" + # Single and double aggregation with axis def q(axes, aggregations): return dictset(query_aggregate(index_docs, axes, aggregations).as_dicts()) - assert (q([Axis("subcat")], [Aggregation("i", "avg")]) == - dictset([{"subcat": "x", "n": 2, "avg_i": 1.5}, {"subcat": "y", "n": 2, "avg_i": 21.0}])) - assert (q([Axis("subcat")], [Aggregation("i", "avg"), Aggregation("i", "max")]) == - dictset([{"subcat": "x", "n": 2, "avg_i": 1.5, "max_i": 2.0}, - {"subcat": "y", "n": 2, "avg_i": 21.0, "max_i": 31.0}])) + + assert q([Axis("subcat")], [Aggregation("i", "avg")]) == dictset( + [{"subcat": "x", "n": 2, "avg_i": 1.5}, {"subcat": "y", "n": 2, "avg_i": 21.0}] + ) + assert q([Axis("subcat")], [Aggregation("i", "avg"), Aggregation("i", "max")]) == dictset( + [{"subcat": "x", "n": 2, "avg_i": 1.5, "max_i": 2.0}, {"subcat": "y", "n": 2, "avg_i": 21.0, "max_i": 31.0}] + ) # Aggregation only - assert (q(None, [Aggregation("i", "avg")]) == dictset([{"n": 4, "avg_i": 11.25}])) - assert (q(None, [Aggregation("i", "avg"), Aggregation("i", "max")]) == - dictset([{"n": 4, "avg_i": 11.25, "max_i": 31.0}])) + assert q(None, [Aggregation("i", "avg")]) == dictset([{"n": 4, "avg_i": 11.25}]) + assert q(None, [Aggregation("i", "avg"), Aggregation("i", "max")]) == dictset([{"n": 4, "avg_i": 11.25, "max_i": 31.0}]) + + # Count only + assert q([], []) == dictset([{"n": 4}]) # Check value handling - Aggregation on date fields - assert (q(None, [Aggregation("date", "max")]) == dictset([{"n": 4, "max_date": "2020-01-01T00:00:00"}])) - assert (q([Axis("subcat")], [Aggregation("date", "avg")]) == - dictset([{"subcat": "x", "n": 2, "avg_date": "2018-01-16T12:00:00"}, - {"subcat": "y", "n": 2, "avg_date": "2019-01-01T00:00:00"}])) + assert q(None, [Aggregation("date", "max")]) == dictset([{"n": 4, "max_date": "2020-01-01T00:00:00"}]) + assert q([Axis("subcat")], [Aggregation("date", "avg")]) == dictset( + [ + {"subcat": "x", "n": 2, "avg_date": "2018-01-16T12:00:00"}, + {"subcat": "y", "n": 2, "avg_date": "2019-01-01T00:00:00"}, + ] + ) def test_aggregate_datefunctions(index: str): q = functools.partial(do_query, index) - docs = [dict(date=x) for x in ["2018-01-01T04:00:00", # monday night - "2018-01-01T09:00:00", # monday morning - "2018-01-11T09:00:00", # thursday morning - "2018-01-17T11:00:00", # wednesday morning - "2018-01-17T18:00:00", # wednesday evening - "2018-03-07T23:59:00", # wednesday evening - ]] - upload(index, docs) - assert q(Axis("date", interval="day")) == {date(2018, 1, 1): 2, date(2018, 1, 11): 1, - date(2018, 1, 17): 2, date(2018, 3, 7): 1} + docs = [ + dict(date=x) + for x in [ + "2018-01-01T04:00:00", # monday night + "2018-01-01T09:00:00", # monday morning + "2018-01-11T09:00:00", # thursday morning + "2018-01-17T11:00:00", # wednesday morning + "2018-01-17T18:00:00", # wednesday evening + "2018-03-07T23:59:00", # wednesday evening + ] + ] + upload(index, docs, fields=dict(date=CreateField(type="date"))) + assert q(Axis("date", interval="day")) == { + date(2018, 1, 1): 2, + date(2018, 1, 11): 1, + date(2018, 1, 17): 2, + date(2018, 3, 7): 1, + } assert q(Axis("date", interval="dayofweek")) == {"Monday": 2, "Wednesday": 3, "Thursday": 1} assert q(Axis("date", interval="daypart")) == {"Night": 1, "Morning": 3, "Evening": 2} assert q(Axis("date", interval="monthnr")) == {1: 5, 3: 1} @@ -108,4 +132,8 @@ def test_aggregate_datefunctions(index: str): assert 
q(Axis("date", interval="dayofmonth")) == {1: 2, 11: 1, 17: 2, 7: 1} assert q(Axis("date", interval="weeknr")) == {1: 2, 2: 1, 3: 2, 10: 1} assert q(Axis("date", interval="month"), Axis("date", interval="dayofmonth")) == { - (date(2018, 1, 1), 1): 2, (date(2018, 1, 1), 11): 1, (date(2018, 1, 1), 17): 2, (date(2018, 3, 1), 7): 1} + (date(2018, 1, 1), 1): 2, + (date(2018, 1, 1), 11): 1, + (date(2018, 1, 1), 17): 2, + (date(2018, 3, 1), 7): 1, + } diff --git a/tests/test_api_documents.py b/tests/test_api_documents.py index 39f4bde..612a219 100644 --- a/tests/test_api_documents.py +++ b/tests/test_api_documents.py @@ -1,5 +1,4 @@ from amcat4.index import set_role, Role -from tests.conftest import populate_index from tests.tools import post_json, build_headers, get_json, check @@ -8,22 +7,16 @@ def test_documents_unauthorized(client, index, user): docs = {"documents": []} check(client.post(f"index/{index}/documents", json=docs), 401) check( - client.post( - f"index/{index}/documents", json=docs, headers=build_headers(user=user) - ), + client.post(f"index/{index}/documents", json=docs, headers=build_headers(user=user)), 401, ) check(client.put(f"index/{index}/documents/1", json={}), 401) check( - client.put( - f"index/{index}/documents/1", json={}, headers=build_headers(user=user) - ), + client.put(f"index/{index}/documents/1", json={}, headers=build_headers(user=user)), 401, ) check(client.get(f"index/{index}/documents/1"), 401) - check( - client.get(f"index/{index}/documents/1", headers=build_headers(user=user)), 401 - ) + check(client.get(f"index/{index}/documents/1", headers=build_headers(user=user)), 401) def test_documents(client, index, user): @@ -34,9 +27,12 @@ def test_documents(client, index, user): f"index/{index}/documents", user=user, json={ - "documents": [ - {"_id": "id", "title": "a title", "text": "text", "date": "2020-01-01"} - ] + "documents": [{"_id": "id", "title": "a title", "text": "text", "date": "2020-01-01"}], + "fields": { + "title": {"type": "text"}, + "text": {"type": "text"}, + "date": {"type": "date"}, + }, }, ) url = f"index/{index}/documents/id" @@ -48,56 +44,3 @@ def test_documents(client, index, user): assert get_json(client, url, user=user)["title"] == "the headline" check(client.delete(url, headers=build_headers(user)), 204) check(client.get(url, headers=build_headers(user)), 404) - - -def test_metareader(client, index, index_docs, user, reader): - set_role(index, user, Role.METAREADER) - set_role(index, reader, Role.READER) - populate_index(index) - - r = get_json( - client, - f"/index/{index}/documents?fields=title", - headers=build_headers(user), - ) - _id = r["results"][0]["_id"] - url = f"index/{index}/documents/{_id}" - # Metareader should not be able to retrieve document source - check(client.get(url, headers=build_headers(user)), 401) - check(client.get(url, headers=build_headers(reader)), 200) - - def get_join(x): - return ",".join(x) if isinstance(x, list) else x - - # Metareader should not be able to query text (including highlight) - for ix, u, fields, highlight, outcome in [ - (index, user, ["text"], False, 401), - (index_docs, user, ["text"], False, 200), - ([index_docs, index], user, ["text"], False, 401), - (index, user, ["text", "title"], False, 401), - (index, user, ["title"], False, 200), - (index, reader, ["text"], False, 200), - ([index_docs, index], reader, ["text"], False, 200), - (index, user, ["title"], True, 401), - (index, reader, ["title"], True, 200), - ]: - check( - client.get( - 
f"/index/{get_join(ix)}/documents?fields={get_join(fields)}{'&highlight=true' if highlight else ''}", - headers=build_headers(u), - ), - outcome, - msg=f"Index: {ix}, user: {u}, fields: {fields}", - ) - body = {"fields": fields} - if highlight: - body["highlight"] = True - check( - client.post( - f"/index/{get_join(ix)}/query", - headers=build_headers(u), - json=body, - ), - outcome, - msg=f"Index: {ix}, user: {u}, fields: {fields}", - ) diff --git a/tests/test_api_errors.py b/tests/test_api_errors.py index 5fac4b9..063d4af 100644 --- a/tests/test_api_errors.py +++ b/tests/test_api_errors.py @@ -12,16 +12,20 @@ def check(client, url, status, message, method="post", user=None, **kargs): raise AssertionError(f"Status {r.status_code} error {repr(r.text)} does not match pattern {repr(message)}") -def test_documents_unauthorized(client, index, writer, ): +def test_documents_unauthorized( + client, + index, + writer, +): check(client, "/index/", 401, "global writer permissions") - check(client, f"/index/{index}/", 401, f"permissions on index {index}", method='get') + check(client, f"/index/{index}/", 401, f"permissions on index {index}", method="get") def test_error_elastic(client, index, admin): for hostname in ("doesnotexist.example.com", "https://doesnotexist.example.com:9200"): - with amcat_settings(elastic_host=hostname): + with amcat_settings(elastic_host=hostname, elastic_verify_ssl=True): es.cache_clear() - check(client, f"/index/{index}/", 500, f"cannot connect.*{hostname}", method='get', user=admin) + check(client, f"/index/{index}/", 500, f"cannot connect.*{hostname}", method="get", user=admin) def test_error_index_create(client, writer, index): diff --git a/tests/test_api_index.py b/tests/test_api_index.py index 2a5b562..a285f99 100644 --- a/tests/test_api_index.py +++ b/tests/test_api_index.py @@ -1,7 +1,9 @@ from starlette.testclient import TestClient from amcat4 import elastic -from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role + +from amcat4.index import GuestRole, get_guest_role, Role, set_guest_role, set_role, remove_role +from amcat4.fields import update_fields from tests.tools import build_headers, post_json, get_json, check, refresh @@ -18,7 +20,7 @@ def test_create_list_delete_index(client, index_name, user, writer, writer2, adm # Writers can create indices post_json(client, "/index/", user=writer, json=dict(id=index_name)) refresh() - assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer)} + assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer) or []} # Users can GET their own index, global writer can GET all indices, others cannot GET non-public indices check(client.get(f"/index/{index_name}"), 401) @@ -27,11 +29,9 @@ def test_create_list_delete_index(client, index_name, user, writer, writer2, adm check(client.get(f"/index/{index_name}", headers=build_headers(user=writer2)), 200) # Users can only see indices that they have a role in or that have a guest role - assert index_name not in {x["name"] for x in get_json(client, "/index/", user=user)} - assert index_name not in { - x["name"] for x in get_json(client, "/index/", user=writer2) - } - assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer)} + assert index_name not in {x["name"] for x in get_json(client, "/index/", user=user) or []} + assert index_name not in {x["name"] for x in get_json(client, "/index/", user=writer2) or []} + assert index_name in {x["name"] for x in get_json(client, "/index/", 
user=writer) or []} # (Only) index admin can change index guest role check(client.put(f"/index/{index_name}", json={"guest_role": "METAREADER"}), 401) @@ -62,7 +62,7 @@ def test_create_list_delete_index(client, index_name, user, writer, writer2, adm assert get_guest_role(index_name).name == "READER" # Index should now be visible to non-authorized users - assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer)} + assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer) or []} check(client.get(f"/index/{index_name}", headers=build_headers(user=user)), 200) @@ -79,55 +79,52 @@ def test_fields_upload(client: TestClient, user: str, index: str): } for i, x in enumerate(["a", "a", "b"]) ], - "columns": {"x": "keyword"}, + "fields": { + "title": "text", + "text": "text", + "date": "date", + "x": "keyword", + }, } # You need METAREADER permissions to read fields, and WRITER to upload docs check(client.get(f"/index/{index}/fields"), 401) check( - client.post( - f"/index/{index}/documents", headers=build_headers(user), json=body - ), + client.post(f"/index/{index}/documents", headers=build_headers(user), json=body), 401, ) set_role(index, user, Role.METAREADER) - fields = get_json(client, f"/index/{index}/fields", user=user) - assert set(fields.keys()) == {"title", "date", "text", "url"} - assert fields["date"]["type"] == "date" + + ## can get fields + fields = get_json(client, f"/index/{index}/fields", user=user) or {} + ## but should still be empty, since no fields were created + assert len(set(fields.keys())) == 0 check( - client.post( - f"/index/{index}/documents", headers=build_headers(user), json=body - ), + client.post(f"/index/{index}/documents", headers=build_headers(user), json=body), 401, ) set_role(index, user, Role.WRITER) post_json(client, f"/index/{index}/documents", user=user, json=body) get_json(client, f"/index/{index}/refresh", expected=204) - doc = get_json(client, f"/index/{index}/documents/0", user=user) + doc = get_json(client, f"/index/{index}/documents/0", user=user) or {} assert set(doc.keys()) == {"date", "text", "title", "x"} assert doc["title"] == "doc 0" # field selection - assert set( - get_json( - client, f"/index/{index}/documents/0", user=user, params={"fields": "title"} - ).keys() - ) == {"title"} - assert ( - get_json(client, f"/index/{index}/fields", user=user)["x"]["type"] == "keyword" - ) + assert set((get_json(client, f"/index/{index}/documents/0", user=user, params={"fields": "title"}) or {}).keys()) == { + "title" + } + assert (get_json(client, f"/index/{index}/fields", user=user) or {})["x"]["type"] == "keyword" elastic.es().indices.refresh() - assert set(get_json(client, f"/index/{index}/fields/x/values", user=user)) == { + assert set(get_json(client, f"/index/{index}/fields/x/values", user=user) or []) == { "a", "b", } -def test_set_get_delete_roles( - client: TestClient, admin: str, writer: str, user: str, index: str -): +def test_set_get_delete_roles(client: TestClient, admin: str, writer: str, user: str, index: str): body = {"email": user, "role": "READER"} # Anon, unauthorized; READER can't add users check(client.post(f"/index/{index}/users", json=body), 401) @@ -159,15 +156,10 @@ def test_set_get_delete_roles( json={"email": writer, "role": "WRITER"}, user=admin, ) - assert get_json(client, f"/index/{index}/users", user=writer) == [ - {"email": writer, "role": "WRITER"} - ] + assert get_json(client, f"/index/{index}/users", user=writer) == [{"email": writer, "role": "WRITER"}] # Writer can now add a 
new user post_json(client, f"/index/{index}/users", json=body, user=writer) - users = { - u["email"]: u["role"] - for u in get_json(client, f"/index/{index}/users", user=writer) - } + users = {u["email"]: u["role"] for u in get_json(client, f"/index/{index}/users", user=writer) or []} assert users == {writer: "WRITER", user: "READER"} # Anon, unauthorized or READER can't change users @@ -183,15 +175,10 @@ def test_set_get_delete_roles( client.put(user_url, json={"role": "WRITER"}, headers=build_headers(writer)), 200, ) - users = { - u["email"]: u["role"] - for u in get_json(client, f"/index/{index}/users", user=writer) - } + users = {u["email"]: u["role"] for u in get_json(client, f"/index/{index}/users", user=writer) or []} assert users == {writer: "WRITER", user: "WRITER"} # Writer can't change to admin - check( - client.put(writer_url, json={"role": "ADMIN"}, headers=build_headers(user)), 401 - ) + check(client.put(writer_url, json={"role": "ADMIN"}, headers=build_headers(user)), 401) # Writer can't change from admin set_role(index, writer, Role.ADMIN) check( @@ -217,18 +204,14 @@ def test_name_description(client, index, index_name, user, admin): check(client.put(f"/index/{index}", json=dict(name="test")), 401) check(client.get(f"/index/{index}"), 401) check( - client.put( - f"/index/{index}", json=dict(name="test"), headers=build_headers(user) - ), + client.put(f"/index/{index}", json=dict(name="test"), headers=build_headers(user)), 401, ) check(client.get(f"/index/{index}", headers=build_headers(user)), 401) # global admin and index writer can change details check( - client.put( - f"/index/{index}", json=dict(name="test"), headers=build_headers(admin) - ), + client.put(f"/index/{index}", json=dict(name="test"), headers=build_headers(admin)), 200, ) set_role(index, user, Role.ADMIN) @@ -242,14 +225,14 @@ def test_name_description(client, index, index_name, user, admin): ) # global admin and index or guest metareader can read details - assert get_json(client, f"/index/{index}", user=admin)["description"] == "ooktest" - assert get_json(client, f"/index/{index}", user=user)["name"] == "test" + assert (get_json(client, f"/index/{index}", user=admin) or {})["description"] == "ooktest" + assert (get_json(client, f"/index/{index}", user=user) or {})["name"] == "test" set_role(index, user, Role.METAREADER) - assert get_json(client, f"/index/{index}", user=user)["name"] == "test" + assert (get_json(client, f"/index/{index}", user=user) or {})["name"] == "test" set_role(index, user, None) check(client.get(f"/index/{index}", headers=build_headers(user)), 401) - set_guest_role(index, Role.METAREADER) - assert get_json(client, f"/index/{index}", user=user)["name"] == "test" + set_guest_role(index, GuestRole.METAREADER) + assert (get_json(client, f"/index/{index}", user=user) or {})["name"] == "test" check( client.post( @@ -257,32 +240,16 @@ def test_name_description(client, index, index_name, user, admin): json=dict( id=index_name, description="test2", - guest_role="metareader", - summary_field="party", + guest_role="METAREADER", ), headers=build_headers(admin), ), 201, ) - assert get_json(client, f"/index/{index_name}", user=user)["description"] == "test2" + assert (get_json(client, f"/index/{index_name}", user=user) or {})["description"] == "test2" # name and description should be present in list of indices refresh() - indices = {ix["id"]: ix for ix in get_json(client, "/index")} + indices = {ix["id"]: ix for ix in get_json(client, "/index") or []} assert indices[index]["description"] == 
"ooktest" assert indices[index_name]["description"] == "test2" - - # can set and get summary field - elastic.set_fields(index_name, {"party": "keyword"}) - refresh() - check( - client.put( - f"/index/{index_name}", - json=dict(summary_field="party"), - headers=build_headers(admin), - ), - 200, - ) - assert ( - get_json(client, f"/index/{index_name}", user=admin)["summary_field"] == "party" - ) diff --git a/tests/test_api_metareader.py b/tests/test_api_metareader.py new file mode 100644 index 0000000..b8dda4d --- /dev/null +++ b/tests/test_api_metareader.py @@ -0,0 +1,89 @@ +from fastapi.testclient import TestClient +from amcat4.models import FieldSpec, SnippetParams + +from tests.tools import build_headers, post_json + + +def create_index_metareader(client, index, admin): + # Create new user and set index role to metareader + client.post("/users", headers=build_headers(admin), json={"email": "meta@reader.com", "role": "METAREADER"}), + client.put(f"/index/{index}/users/meta@reader.com", headers=build_headers(admin), json={"role": "METAREADER"}), + + +def set_metareader_access(client, index, admin, metareader): + client.put( + f"/index/{index}/fields", + headers=build_headers(admin), + json={"text": {"metareader": metareader}}, + ) + + +def check_allowed(client, index: str, field: FieldSpec, allowed=True): + post_json( + client, + f"/index/{index}/query", + user="meta@reader.com", + expected=200 if allowed else 401, + json={"fields": [field.model_dump()]}, + ) + + +def test_metareader_none(client: TestClient, admin, index_docs): + """ + Set text field to metareader_access=none + Metareader should not be able to get field both full and as snippet + """ + create_index_metareader(client, index_docs, admin) + set_metareader_access(client, index_docs, admin, {"access": "none"}) + + full = FieldSpec(name="text") + snippet = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=150, max_matches=3, match_chars=50)) + + check_allowed(client, index_docs, full, allowed=False) + check_allowed(client, index_docs, field=snippet, allowed=False) + + +def test_metareader_read(client: TestClient, admin, index_docs): + """ + Set text field to metareader_access=read + Metareader should be able to get field both full and as snippet + """ + create_index_metareader(client, index_docs, admin) + set_metareader_access(client, index_docs, admin, {"access": "read"}) + + full = FieldSpec(name="text") + snippet = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=150, max_matches=3, match_chars=50)) + + check_allowed(client, index_docs, field=full, allowed=True) + check_allowed(client, index_docs, field=snippet, allowed=True) + + +def test_metareader_snippet(client: TestClient, admin, index_docs): + """ + Set text field to metareader_access=snippet[50;1;20] + Metareader should only be able to get field as snippet + with maximum parameters of nomatch_chars=50, max_matches=1, match_chars=20 + """ + create_index_metareader(client, index_docs, admin) + set_metareader_access( + client, + index_docs, + admin, + {"access": "snippet", "max_snippet": {"nomatch_chars": 50, "max_matches": 1, "match_chars": 20}}, + ) + + full = FieldSpec(name="text") + snippet_too_long = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=51, max_matches=1, match_chars=20)) + snippet_too_many_matches = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=50, max_matches=2, match_chars=20)) + snippet_too_long_matches = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=50, max_matches=1, match_chars=21)) + + 
snippet_just_right = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=50, max_matches=1, match_chars=20)) + snippet_less_than_allowed = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=49, max_matches=0, match_chars=19)) + + check_allowed(client, index_docs, field=full, allowed=False) + check_allowed(client, index_docs, field=snippet_too_long, allowed=False) + check_allowed(client, index_docs, field=snippet_too_many_matches, allowed=False) + check_allowed(client, index_docs, field=snippet_too_long_matches, allowed=False) + + check_allowed(client, index_docs, field=snippet_just_right, allowed=True) + check_allowed(client, index_docs, field=snippet_less_than_allowed, allowed=True) diff --git a/tests/test_api_multimedia.py b/tests/test_api_multimedia.py new file mode 100644 index 0000000..ce28495 --- /dev/null +++ b/tests/test_api_multimedia.py @@ -0,0 +1,74 @@ +from fastapi.testclient import TestClient +import pytest +import requests +from amcat4 import multimedia +from amcat4.index import set_role, Role +from tests.tools import post_json, build_headers, get_json, check + + +def _get_names(client: TestClient, index, user, **kargs): + res = client.get(f"index/{index}/multimedia/list", params=kargs, headers=build_headers(user)) + res.raise_for_status() + return {obj["key"] for obj in res.json()} + + +def test_authorisation(minio, client, index, user, reader): + check(client.get(f"index/{index}/multimedia/list"), 401) + check(client.get(f"index/{index}/multimedia/presigned_get", params=dict(key="")), 401) + check(client.get(f"index/{index}/multimedia/presigned_post"), 401) + + set_role(index, user, Role.METAREADER) + set_role(index, reader, Role.READER) + check(client.get(f"index/{index}/multimedia/list", headers=build_headers(user)), 401) + check(client.get(f"index/{index}/multimedia/presigned_get", params=dict(key=""), headers=build_headers(user)), 401) + check(client.get(f"index/{index}/multimedia/presigned_post", headers=build_headers(reader)), 401) + + +def test_post_get_list(minio, client, index, user): + pytest.skip("mock minio does not allow presigned post, skipping for now") + set_role(index, user, Role.WRITER) + assert _get_names(client, index, user) == set() + post = client.get(f"index/{index}/multimedia/presigned_post", headers=build_headers(user)).json() + assert set(post.keys()) == {"url", "form_data"} + multimedia.add_multimedia_object(index, "test", b"bytes") + assert _get_names(client, index, user) == {"test"} + res = client.get(f"index/{index}/multimedia/presigned_get", headers=build_headers(user), params=dict(key="test")) + res.raise_for_status() + assert requests.get(res.json()["url"]).content == b"bytes" + + +def test_list_options(minio, client, index, reader): + set_role(index, reader, Role.READER) + multimedia.add_multimedia_object(index, "myfolder/a1", b"a1") + multimedia.add_multimedia_object(index, "myfolder/a2", b"a2") + multimedia.add_multimedia_object(index, "obj1", b"obj1") + multimedia.add_multimedia_object(index, "obj2", b"obj2") + multimedia.add_multimedia_object(index, "obj3", b"obj3") + multimedia.add_multimedia_object(index, "zzz", b"zzz") + + assert _get_names(client, index, reader) == {"obj1", "obj2", "obj3", "myfolder/", "zzz"} + assert _get_names(client, index, reader, recursive=True) == {"obj1", "obj2", "obj3", "myfolder/a1", "myfolder/a2", "zzz"} + assert _get_names(client, index, reader, prefix="obj") == {"obj1", "obj2", "obj3"} + assert _get_names(client, index, reader, prefix="myfolder/") == {"myfolder/a1", "myfolder/a2"} + assert 
_get_names(client, index, reader, prefix="myfolder/", presigned_get=True) == {"myfolder/a1", "myfolder/a2"} + res = client.get( + f"index/{index}/multimedia/list", params=dict(prefix="myfolder/", presigned_get=True), headers=build_headers(reader) + ) + res.raise_for_status() + assert all("presigned_get" in o for o in res.json()) + + +def test_list_pagination(minio, client, index, reader): + set_role(index, reader, Role.READER) + ids = [f"obj_{i:02}" for i in range(15)] + for id in ids: + multimedia.add_multimedia_object(index, id, id.encode("utf-8")) + + # default page size is 10 + names = _get_names(client, index, reader) + assert names == set(ids[:10]) + more_names = _get_names(client, index, reader, start_after=ids[9]) + assert more_names == set(ids[10:]) + + names = _get_names(client, index, reader, n=5) + assert names == set(ids[:5]) diff --git a/tests/test_api_pagination.py b/tests/test_api_pagination.py index 62c524d..747a1c8 100644 --- a/tests/test_api_pagination.py +++ b/tests/test_api_pagination.py @@ -1,66 +1,56 @@ from amcat4.index import Role, set_role +from amcat4.models import CreateField from tests.conftest import upload from tests.tools import get_json, post_json def test_pagination(client, index, user): """Does basic pagination work?""" - set_role(index, user, Role.METAREADER) + set_role(index, user, Role.READER) + + # TODO. Tests are not independent. test_pagination fails if run directly after other tests. + # Probably delete_index doesn't fully delete + + upload(index, docs=[{"i": i} for i in range(66)], fields={"i": "integer"}) + url = f"/index/{index}/query" + r = post_json(client, url, user=user, json={"sort": "i", "per_page": 20, "fields": ["i"]}, expected=200) - upload(index, docs=[{"i": i} for i in range(66)]) - url = f"/index/{index}/documents" - r = get_json( - client, url, user=user, params={"sort": "i", "per_page": 20, "fields": ["i"]} - ) assert r["meta"]["per_page"] == 20 assert r["meta"]["page"] == 0 assert r["meta"]["page_count"] == 4 assert {h["i"] for h in r["results"]} == set(range(20)) - r = get_json( - client, - url, - user=user, - params={"sort": "i", "per_page": 20, "page": 3, "fields": ["i"]}, - ) + r = post_json(client, url, user=user, json={"sort": "i", "per_page": 20, "page": 3, "fields": ["i"]}, expected=200) assert r["meta"]["page"] == 3 assert {h["i"] for h in r["results"]} == {60, 61, 62, 63, 64, 65} - r = get_json( - client, url, user=user, params={"sort": "i", "per_page": 20, "page": 4} - ) + r = post_json(client, url, user=user, json={"sort": "i", "per_page": 20, "page": 4, "fields": ["i"]}, expected=200) assert len(r["results"]) == 0 - # Test POST query - - r = post_json( - client, - f"/index/{index}/query", - expected=200, - user=user, - json={"sort": "i", "per_page": 20, "page": 3, "fields": ["i"]}, - ) - assert r["meta"]["page"] == 3 - assert {h["i"] for h in r["results"]} == {60, 61, 62, 63, 64, 65} def test_scroll(client, index, user): - set_role(index, user, Role.METAREADER) - upload(index, docs=[{"i": i} for i in range(66)]) - url = f"/index/{index}/documents" - r = get_json( + set_role(index, user, Role.READER) + upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(type="integer")}) + url = f"/index/{index}/query" + r = post_json( client, url, user=user, - params={"sort": "i:desc", "per_page": 30, "scroll": "5m", "fields": ["i"]}, + json={"scroll": "5m", "sort": [{"i": {"order": "desc"}}], "per_page": 30, "fields": ["i"]}, + expected=200, ) + scroll_id = r["meta"]["scroll_id"] assert scroll_id is not None 
assert {h["i"] for h in r["results"]} == set(range(36, 66)) - r = get_json(client, url, user=user, params={"scroll_id": scroll_id}) + + r = post_json(client, url, user=user, json={"scroll_id": scroll_id}, expected=200) assert {h["i"] for h in r["results"]} == set(range(6, 36)) assert r["meta"]["scroll_id"] == scroll_id - r = get_json(client, url, user=user, params={"scroll_id": scroll_id}) + r = post_json(client, url, user=user, json={"scroll_id": scroll_id}, expected=200) assert {h["i"] for h in r["results"]} == set(range(6)) + # Scrolling past the edge should return 404 - get_json(client, url, user=user, params={"scroll_id": scroll_id}, expected=404) + post_json(client, url, user=user, json={"scroll_id": scroll_id}, expected=404) + # Test POST to query endpoint r = post_json( client, diff --git a/tests/test_api_preprocessing.py b/tests/test_api_preprocessing.py new file mode 100644 index 0000000..85e5d28 --- /dev/null +++ b/tests/test_api_preprocessing.py @@ -0,0 +1,132 @@ +import asyncio +import json +import logging +import httpx +import pytest + +from amcat4.index import Role, add_instruction, get_document, refresh_index, set_role +from amcat4.preprocessing.models import PreprocessingInstruction +from amcat4.preprocessing.processor import get_manager +from tests.conftest import TEST_DOCUMENTS +from tests.test_preprocessing import INSTRUCTION +from tests.tools import aget_json, build_headers, check, get_json + +logger = logging.getLogger("amcat4.tests") + + +def test_get_tasks(client): + # UPDATE after we make a proper 'task store' + res = client.get("/preprocessing_tasks") + res.raise_for_status() + assert any(task["name"] == "HuggingFace Zero-Shot" for task in res.json()) + + +def test_auth(client, index, user): + check(client.get(f"/index/{index}/preprocessing"), 401) + check(client.post(f"/index/{index}/preprocessing", json=INSTRUCTION), 401) + set_role(index, user, Role.READER) + + check(client.get(f"/index/{index}/preprocessing", headers=build_headers(user=user)), 200) + check(client.post(f"/index/{index}/preprocessing", json=INSTRUCTION, headers=build_headers(user=user)), 401) + + +@pytest.mark.asyncio +async def test_post_get_instructions(client, user, index_docs, httpx_mock): + i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) + + set_role(index_docs, user, Role.WRITER) + res = client.get(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user)) + res.raise_for_status() + assert len(res.json()) == 0 + + httpx_mock.add_response(url=i.endpoint, json={"labels": ["games", "sports"], "scores": [0.9, 0.1]}) + + res = client.post(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user), json=i.model_dump()) + res.raise_for_status() + refresh_index(index_docs) + res = client.get(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user)) + res.raise_for_status() + assert {item["field"] for item in res.json()} == {i.field} + + while len(httpx_mock.get_requests()) < len(TEST_DOCUMENTS): + await asyncio.sleep(0.1) + await asyncio.sleep(0.1) + assert all(get_document(index_docs, doc["_id"])["class_label"] == "games" for doc in TEST_DOCUMENTS) + + # Cannot re-add the same field + check(client.post(f"/index/{index_docs}/preprocessing", json=i.model_dump(), headers=build_headers(user=user)), 400) + + +@pytest.mark.asyncio +async def test_pause_restart(aclient: httpx.AsyncClient, admin, index_docs, httpx_mock, caplog): + async def slow_response(request): + json.loads(request.content)["inputs"] + await asyncio.sleep(0.1) + return 
httpx.Response(json={"labels": ["politics"], "scores": [1]}, status_code=200) + + i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) + httpx_mock.add_callback(slow_response, url=i.endpoint) + status_url = f"/index/{index_docs}/preprocessing/{i.field}/status" + + # Start the preprocessor, wait .15 seconds + add_instruction(index_docs, i) + await asyncio.sleep(0.15) + + assert (await aget_json(aclient, status_url, user=admin))["status"] == "Active" + + # Set the processor to pause + check(await aclient.post(status_url, json=dict(action="Stop"), headers=build_headers(user=admin)), 204) + await asyncio.sleep(0) + assert (await aget_json(aclient, status_url, user=admin))["status"] == "Stopped" + + # Some, but not all docs should be done yet + assert len(httpx_mock.get_requests()) < len(TEST_DOCUMENTS) + assert len(httpx_mock.get_requests()) > 0 + + # Restart processor + check(await aclient.post(status_url, json=dict(action="Start"), headers=build_headers(user=admin)), 204) + await asyncio.sleep(0) + assert (await aget_json(aclient, status_url, user=admin))["status"] == "Active" + + await get_manager().running_tasks[index_docs, i.field] + assert (await aget_json(aclient, status_url, user=admin))["status"] == "Done" + + # There should be at most one extra request (the cancelled one) + assert len(httpx_mock.get_requests()) <= len(TEST_DOCUMENTS) + 1 + + +@pytest.mark.asyncio +async def test_reassign_error(aclient: httpx.AsyncClient, admin, index_docs, httpx_mock): + async def mistakes_were_made(request): + await asyncio.sleep(0.1) + input = json.loads(request.content)["inputs"] + if "text" in input: # should be true for 2 documents + return httpx.Response(json={"kettle": "black"}, status_code=418) + else: + return httpx.Response(json={"labels": ["first pass"]}, status_code=200) + + i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) + httpx_mock.add_callback(mistakes_were_made, url=i.endpoint) + + # Start the preprocessor, wait .15 seconds + add_instruction(index_docs, i) + await get_manager().running_tasks[index_docs, i.field] + field_url = f"/index/{index_docs}/preprocessing/{i.field}" + status_url = f"{field_url}/status" + + res = await aget_json(aclient, field_url, user=admin) + assert res["status"] == "Done" + assert res["counts"] == {"total": 4, "done": 2, "error": 2} + + httpx_mock.reset(True) + httpx_mock.add_response(url=i.endpoint, json={"labels": ["secondpass"]}) + + check(await aclient.post(status_url, json=dict(action="Reassign"), headers=build_headers(user=admin)), 204) + await get_manager().running_tasks[index_docs, i.field] + + res = await aget_json(aclient, field_url, user=admin) + assert res["status"] == "Done" + assert res["counts"] == {"total": 4, "done": 4} + + # Check that only error'd documents are reassigned + assert len(httpx_mock.get_requests()) == 2 diff --git a/tests/test_api_query.py b/tests/test_api_query.py index 46d25f0..4c9f592 100644 --- a/tests/test_api_query.py +++ b/tests/test_api_query.py @@ -1,78 +1,13 @@ from amcat4.index import Role, refresh_index, set_role +from amcat4.models import CreateField, FieldSpec from amcat4.query import query_documents from tests.conftest import upload -from tests.tools import get_json, post_json, dictset - -TEST_DOCUMENTS = [ - { - "cat": "a", - "subcat": "x", - "i": 1, - "date": "2018-01-01", - "text": "this is a text", - }, - { - "cat": "a", - "subcat": "x", - "i": 2, - "date": "2018-02-01", - "text": "a test text", - }, - { - "cat": "a", - "subcat": "y", - "i": 11, - "date": 
"2020-01-01", - "text": "and this is another test toto", - "title": "bla", - }, - { - "cat": "b", - "subcat": "y", - "i": 31, - "date": "2018-01-01", - "text": "Toto je testovací článek", - "title": "more bla", - }, -] - - -def test_query_get(client, index_docs, user): - """Can we run a simple query?""" - - def q(**query_string): - return get_json( - client, f"/index/{index_docs}/documents", user=user, params=query_string - )["results"] - - def qi(**query_string): - return {int(doc["_id"]) for doc in q(**query_string)} - - # TODO: check auth - # Query strings - assert qi(q="text") == {0, 1} - assert qi(q="test*") == {1, 2, 3} - - # Filters - assert qi(cat="a") == {0, 1, 2} - assert qi(cat="b", q="test*") == {3} - assert qi(date="2018-01-01") == {0, 3} - assert qi(date__gte="2018-02-01") == {1, 2} - assert qi(date__gt="2018-02-01") == {2} - assert qi(date__gte="2018-02-01", date__lt="2020-01-01") == {1} - - # Can we request specific fields? - default_fields = {"_id", "date", "title"} - assert set(q()[0].keys()) == default_fields - assert set(q(fields="cat")[0].keys()) == {"_id", "cat"} - assert set(q(fields="date,title")[0].keys()) == {"_id", "date", "title"} +from tests.tools import build_headers, check, post_json, dictset def test_query_post(client, index_docs, user): def q(**body): - return post_json( - client, f"/index/{index_docs}/query", user=user, expected=200, json=body - )["results"] + return post_json(client, f"/index/{index_docs}/query", user=user, expected=200, json=body)["results"] def qi(**query_string): return {int(doc["_id"]) for doc in q(**query_string)} @@ -93,8 +28,8 @@ def qi(**query_string): assert qi(filters={"cat": {"values": ["a"]}}) == {0, 1, 2} # Can we request specific fields? - default_fields = {"_id", "date", "title"} - assert set(q()[0].keys()) == default_fields + all_fields = {"_id", "cat", "subcat", "i", "date", "text", "title"} + assert set(q()[0].keys()) == all_fields assert set(q(fields=["cat"])[0].keys()) == {"_id", "cat"} assert set(q(fields=["date", "title"])[0].keys()) == {"_id", "date", "title"} @@ -122,12 +57,8 @@ def test_aggregate(client, index_docs, user): "aggregations": [{"field": "i", "function": "avg"}], }, ) - assert dictset(r["data"]) == dictset( - [{"avg_i": 1.5, "n": 2, "subcat": "x"}, {"avg_i": 21.0, "n": 2, "subcat": "y"}] - ) - assert r["meta"]["aggregations"] == [ - {"field": "i", "function": "avg", "type": "long", "name": "avg_i"} - ] + assert dictset(r["data"]) == dictset([{"avg_i": 1.5, "n": 2, "subcat": "x"}, {"avg_i": 21.0, "n": 2, "subcat": "y"}]) + assert r["meta"]["aggregations"] == [{"field": "i", "function": "avg", "type": "integer", "name": "avg_i"}] # test filtered aggregate r = post_json( @@ -156,35 +87,66 @@ def test_aggregate(client, index_docs, user): assert data == {"x": 2} +def test_bare_aggregate(client, index_docs, user): + r = post_json( + client, + f"/index/{index_docs}/aggregate", + user=user, + expected=200, + json={}, + ) + assert r["meta"]["axes"] == [] + assert r["data"] == [dict(n=4)] + + r = post_json( + client, + f"/index/{index_docs}/aggregate", + user=user, + expected=200, + json={"aggregations": [{"field": "i", "function": "avg"}]}, + ) + assert r["data"] == [dict(n=4, avg_i=11.25)] + + r = post_json( + client, + f"/index/{index_docs}/aggregate", + user=user, + expected=200, + json={"aggregations": [{"field": "i", "function": "min", "name": "mini"}]}, + ) + assert r["data"] == [dict(n=4, mini=1)] + + def test_multiple_index(client, index_docs, index, user): set_role(index, user, Role.READER) upload( 
index, [{"text": "also a text", "i": -1, "cat": "c"}], - fields={"cat": "keyword", "i": "long"}, + fields={ + "text": CreateField(type="text"), + "cat": CreateField(type="keyword"), + "i": CreateField(type="integer"), + }, ) indices = f"{index},{index_docs}" - assert ( - len(get_json(client, f"/index/{indices}/documents", user=user)["results"]) == 5 - ) - assert ( - len( - post_json(client, f"/index/{indices}/query", user=user, expected=200)[ - "results" - ] - ) - == 5 + + r = post_json( + client, + f"/index/{indices}/query", + user=user, + expected=200, + json=dict(fields=["_id", "cat", "i"]), ) + assert len(r["results"]) == 5 + r = post_json( client, f"/index/{indices}/aggregate", user=user, - json={"axes": [{"field": "cat"}]}, + json={"axes": [{"field": "cat"}], "fields": ["_id"]}, expected=200, ) - assert dictset(r["data"]) == dictset( - [{"cat": "a", "n": 3}, {"n": 1, "cat": "b"}, {"n": 1, "cat": "c"}] - ) + assert dictset(r["data"]) == dictset([{"cat": "a", "n": 3}, {"n": 1, "cat": "b"}, {"n": 1, "cat": "c"}]) def test_aggregate_datemappings(client, index_docs, user): @@ -218,37 +180,41 @@ def test_aggregate_datemappings(client, index_docs, user): def test_query_tags(client, index_docs, user): def tags(): - return { - doc["_id"]: doc["tag"] - for doc in query_documents(index_docs, fields=["tag"]).data - if doc.get("tag") - } + result = query_documents(index_docs, fields=[FieldSpec(name="tag")]) + return {doc["_id"]: doc["tag"] for doc in (result.data if result else []) if doc.get("tag")} + + check(client.post(f"/index/{index_docs}/tags_update"), 401) + check(client.post(f"/index/{index_docs}/tags_update", headers=build_headers(user=user)), 401) + + set_role(index_docs, user, Role.WRITER) assert tags() == {} - post_json( + res = post_json( client, f"/index/{index_docs}/tags_update", user=user, - expected=204, + expected=200, json=dict(action="add", field="tag", tag="x", filters={"cat": "a"}), ) - refresh_index(index_docs) + assert res["updated"] == 3 + # should refresh before returning + # refresh_index(index_docs) assert tags() == {"0": ["x"], "1": ["x"], "2": ["x"]} - post_json( + res = post_json( client, f"/index/{index_docs}/tags_update", user=user, - expected=204, + expected=200, json=dict(action="remove", field="tag", tag="x", queries=["text"]), ) - refresh_index(index_docs) + assert res["updated"] == 2 assert tags() == {"2": ["x"]} - post_json( + res = post_json( client, f"/index/{index_docs}/tags_update", user=user, - expected=204, + expected=200, json=dict(action="add", field="tag", tag="y", ids=["1", "2"]), ) - refresh_index(index_docs) + assert res["updated"] == 2 assert tags() == {"1": ["y"], "2": ["x", "y"]} diff --git a/tests/test_api_user.py b/tests/test_api_user.py index afebc80..f15f206 100644 --- a/tests/test_api_user.py +++ b/tests/test_api_user.py @@ -29,6 +29,10 @@ def test_auth(client: TestClient, user, admin, index): assert client.get(f"/index/{index}", headers=build_headers(admin)).status_code == 200 with set_auth(AuthOptions.authorized_users_only): # Only users with a index-level role can access other indices (even as guest) + # KW: I don't understand what this means. Do we need to check every index? + # Now changed it so that only users with a server level role can access other indices as guest. + # In other words, in this auth mode you either need index level authorization or server level + # authorization with guest access. 
(this did pass the test)
         set_guest_role(index, Role.READER)
         refresh()
         assert client.get(f"/index/{index}").status_code == 401
@@ -45,24 +49,26 @@ def test_get_user(client: TestClient, writer, user):
     assert get_json(client, f"/users/{user}", user=user) == {"email": user, "role": "READER"}
     # writer can see everyone
     assert get_json(client, f"/users/{user}", user=writer) == {"email": user, "role": "READER"}
-    assert get_json(client, f"/users/{writer}", user=writer) == {"email": writer, "role": 'WRITER'}
+    assert get_json(client, f"/users/{writer}", user=writer) == {"email": writer, "role": "WRITER"}
 
     # Retrieving a non-existing user as admin should give 404
     delete_user(user)
-    assert client.get(f'/users/{user}', headers=build_headers(writer)).status_code == 404
+    assert client.get(f"/users/{user}", headers=build_headers(writer)).status_code == 404
 
 
 def test_create_user(client: TestClient, user, writer, admin, username):
     # anonymous or unprivileged users cannot create new users
-    assert client.post('/users/').status_code == 401, "Creating user should require auth"
+    assert client.post("/users/").status_code == 401, "Creating user should require auth"
    assert client.post("/users/", headers=build_headers(writer)).status_code == 401, "Creating user should require admin"
     # users need global role
-    assert client.post("/users/", headers=build_headers(admin), json=dict(email=username)).status_code == 400, \
-        "Duplicate create should return 400"
+    assert (
+        client.post("/users/", headers=build_headers(admin), json=dict(email=username)).status_code == 400
+    ), "Duplicate create should return 400"
     # admin can add new users
-    u = dict(email=username, role="writer")
-    assert "email" in set(post_json(client, "/users/", user=admin, json=u).keys())
-    assert client.post("/users/", headers=build_headers(admin), json=u).status_code == 400, \
-        "Duplicate create should return 400"
+    u = dict(email=username, role="WRITER")
+    assert "email" in set((post_json(client, "/users/", user=admin, json=u) or {}).keys())
+    assert (
+        client.post("/users/", headers=build_headers(admin), json=u).status_code == 400
+    ), "Duplicate create should return 400"
     # users can delete themselves, others cannot delete them
     assert client.delete(f"/users/{username}", headers=build_headers(writer)).status_code == 401
@@ -75,8 +81,8 @@ def test_create_user(client: TestClient, user, writer, admin, username):
 
 
 def test_modify_user(client: TestClient, user, writer, admin):
     """Are the API endpoints and auth for modifying users correct?"""
     # Only admin can change users
-    check(client.put(f"/users/{user}", headers=build_headers(user), json={'role': 'metareader'}), 401)
-    check(client.put(f"/users/{user}", headers=build_headers(admin), json={'role': 'admin'}), 200)
+    check(client.put(f"/users/{user}", headers=build_headers(user), json={"role": "METAREADER"}), 401)
+    check(client.put(f"/users/{user}", headers=build_headers(admin), json={"role": "ADMIN"}), 200)
     assert get_global_role(user).name == "ADMIN"
@@ -84,6 +90,6 @@ def test_list_users(client: TestClient, index, admin, user):
     # You need global WRITER rights to list users
     check(client.get("/users"), 401)
     check(client.get("/users", headers=build_headers(user)), 401)
-    result = get_json(client, "/users", user=admin)
-    assert {'email': admin, 'role': 'ADMIN'} in result
-    assert {'email': user, 'role': 'READER'} in result
+    result = get_json(client, "/users", user=admin) or []
+    assert {"email": admin, "role": "ADMIN"} in result
+    assert {"email": user, "role": "READER"} in result
diff --git a/tests/test_elastic.py b/tests/test_elastic.py
index a4ee366..4985a89 100644
--- a/tests/test_elastic.py
+++ b/tests/test_elastic.py
@@ -1,56 +1,78 @@
 from datetime import datetime
-
-from amcat4 import elastic
-from amcat4.elastic import get_fields
-from amcat4.index import refresh_index
+from re import I
+
+import pytest
+
+from amcat4.index import (
+    refresh_index,
+    upload_documents,
+    get_document,
+    update_document,
+    update_tag_by_query,
+)
+from amcat4.fields import create_fields, update_fields, get_fields, field_values
+from amcat4.models import CreateField, FieldSpec
 from amcat4.query import query_documents
 from tests.conftest import upload
 
 
 def test_upload_retrieve_document(index):
     """Can we upload and retrieve documents"""
-    a = dict(text="text", title="title", date="2021-03-09", _id="test", term_tfidf=[
-        {"term": "test", "value": 0.2},
-        {"term": "value", "value": 0.3}
-    ])
-    elastic.upload_documents(index, [a])
-    d = elastic.get_document(index, "test")
-    assert d['title'] == a['title']
-    assert d['term_tfidf'] == a['term_tfidf']
+    a = dict(
+        text="text",
+        title="title",
+        date="2021-03-09",
+        _id="test",
+        term_tfidf=[{"term": "test", "value": 0.2}, {"term": "value", "value": 0.3}],
+    )
+    upload_documents(index, [a], fields={"text": "text", "title": "text", "date": "date", "term_tfidf": "object"})
+    d = get_document(index, "test")
+    assert d["title"] == a["title"]
+    assert d["term_tfidf"] == a["term_tfidf"]
     # TODO: should a['date'] be a datetime?
 
 
 def test_data_coerced(index):
     """Are field values coerced to the correct field type"""
-    elastic.set_fields(index, {"i": "long"})
-    a = dict(_id="DoccyMcDocface", text="text", title="test-numeric", date="2022-12-13", i="1")
-    elastic.upload_documents(index, [a])
-    d = elastic.get_document(index, "DoccyMcDocface")
-    assert isinstance(d["i"], float)
+    create_fields(index, {"i": "integer", "x": "number", "title": "text", "date": "date", "text": "text"})
+    a = dict(_id="DoccyMcDocface", text="text", title="test-numeric", date="2022-12-13", i="1", x="1.1")
+    upload_documents(index, [a])
+    d = get_document(index, "DoccyMcDocface")
+    assert isinstance(d["i"], int)
     a = dict(text="text", title=1, date="2022-12-13")
-    elastic.upload_documents(index, [a])
-    d = elastic.get_document(index, "DoccyMcDocface")
+    upload_documents(index, [a])
+    d = get_document(index, "DoccyMcDocface")
     assert isinstance(d["title"], str)
 
 
 def test_fields(index):
     """Can we get the fields from an index"""
+    create_fields(index, {"title": "text", "date": "date", "text": "text", "url": "keyword"})
     fields = get_fields(index)
     assert set(fields.keys()) == {"title", "date", "text", "url"}
-    assert fields['date']['type'] == "date"
+    assert fields["title"].type == "text"
+    assert fields["date"].type == "date"
+
+    # default settings
+    assert fields["date"].identifier == False
+    assert fields["date"].client_settings is not None
+
+    # default settings depend on the type
+    assert fields["date"].metareader.access == "read"
+    assert fields["text"].metareader.access == "none"
 
 
 def test_values(index):
     """Can we get values for a specific field"""
     upload(index, [dict(bla=x) for x in ["odd", "even", "even"] * 10], fields={"bla": "keyword"})
-    assert set(elastic.get_values(index, "bla")) == {"odd", "even"}
+    assert set(field_values(index, "bla", 10)) == {"odd", "even"}
 
 
 def test_update(index_docs):
     """Can we update a field on a document?"""
-    assert elastic.get_document(index_docs, '0', _source=['annotations']) == {}
-    elastic.update_document(index_docs, '0', {'annotations': {'x': 3}})
-
assert elastic.get_document(index_docs, '0', _source=['annotations'])['annotations'] == {'x': 3} + assert get_document(index_docs, "0", _source=["annotations"]) == {} + update_document(index_docs, "0", {"annotations": {"x": 3}}) + assert get_document(index_docs, "0", _source=["annotations"])["annotations"] == {"x": 3} def test_add_tag(index_docs): @@ -58,30 +80,110 @@ def q(*ids): return dict(query=dict(ids={"values": ids})) def tags(): - return {doc['_id']: doc['tag'] - for doc in query_documents(index_docs, fields=["tag"]).data - if 'tag' in doc and doc['tag'] is not None} + res = query_documents(index_docs, fields=[FieldSpec(name="tag")]) + return {doc["_id"]: doc["tag"] for doc in (res.data if res else []) if "tag" in doc and doc["tag"] is not None} assert tags() == {} - elastic.update_tag_by_query(index_docs, "add", q('0', '1'), "tag", "x") + update_tag_by_query(index_docs, "add", q("0", "1"), "tag", "x") refresh_index(index_docs) - assert tags() == {'0': ['x'], '1': ['x']} - elastic.update_tag_by_query(index_docs, "add", q('1', '2'), "tag", "x") + assert tags() == {"0": ["x"], "1": ["x"]} + update_tag_by_query(index_docs, "add", q("1", "2"), "tag", "x") refresh_index(index_docs) - assert tags() == {'0': ['x'], '1': ['x'], '2': ['x']} - elastic.update_tag_by_query(index_docs, "add", q('2', '3'), "tag", "y") + assert tags() == {"0": ["x"], "1": ["x"], "2": ["x"]} + update_tag_by_query(index_docs, "add", q("2", "3"), "tag", "y") refresh_index(index_docs) - assert tags() == {'0': ['x'], '1': ['x'], '2': ['x', 'y'], '3': ['y']} - elastic.update_tag_by_query(index_docs, "remove", q('0', '2', '3'), "tag", "x") + assert tags() == {"0": ["x"], "1": ["x"], "2": ["x", "y"], "3": ["y"]} + update_tag_by_query(index_docs, "remove", q("0", "2", "3"), "tag", "x") refresh_index(index_docs) - assert tags() == {'1': ['x'], '2': ['y'], '3': ['y']} + assert tags() == {"1": ["x"], "2": ["y"], "3": ["y"]} -def test_deduplication(index): +def test_upload_without_identifiers(index): doc = {"title": "titel", "text": "text", "date": datetime(2020, 1, 1)} - elastic.upload_documents(index, [doc]) + res = upload_documents(index, [doc], fields={"title": "text", "text": "text", "date": "date"}) + assert res["successes"] == 1 + _assert_n(index, 1) + + # this doesnt identify duplicates + res = upload_documents(index, [doc]) + assert res["successes"] == 1 + _assert_n(index, 2) + + +def test_upload_with_explicit_ids(index): + doc = {"_id": "1", "title": "titel", "text": "text", "date": datetime(2020, 1, 1)} + res = upload_documents(index, [doc], fields={"title": "text", "text": "text", "date": "date"}) + assert res["successes"] == 1 + + # this does skip docs with same id + res = upload_documents(index, [doc]) + assert res["successes"] == 0 + _assert_n(index, 1) + + +def test_upload_with_identifiers(index): + doc = {"url": "http://", "text": "text"} + res = upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True), "text": "text"}) + assert res["successes"] == 1 + _assert_n(index, 1) + + doc2 = {"url": "http://", "text": "text2"} + res = upload_documents(index, [doc2]) + assert res["successes"] == 0 + _assert_n(index, 1) + + doc3 = {"url": "http://2", "text": "text"} + res = upload_documents(index, [doc3]) + assert res["successes"] == 1 + _assert_n(index, 2) + + # cannot upload explicit id if identifiers are used + doc4 = {"_id": "1", "url": "http://", "text": "text"} + with pytest.raises(ValueError): + upload_documents(index, [doc4]) + + +def test_invalid_adding_identifiers(index): + # 
identifiers can only be added if (1) the index already uses identifiers or (2) the index is still empty (no docs) + doc = {"text": "text"} + upload_documents(index, [doc], fields={"text": "text"}) refresh_index(index) - assert query_documents(index).total_count == 1 - elastic.upload_documents(index, [doc]) + + # adding an identifier to an existing index should fail + doc = {"url": "http://", "text": "text"} + with pytest.raises(ValueError): + upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True)}) + + +def test_valid_adding_identifiers(index): + doc = {"text": "text"} + upload_documents(index, [doc], fields={"text": CreateField(type="text", identifier=True)}) + + # adding an additional identifier to an existing index should succeed if the index already has identifiers + doc = {"url": "http://", "text": "text"} + res = upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True)}) + + # the document should have been added because its not a full duplicate (in first doc url was empty) + assert res["successes"] == 1 + + # both the identifier for the first doc and the second doc should still work, so the following docs are + # both duplicates + doc1 = {"text": "text"} + doc2 = {"url": "http://", "text": "text"} + res = upload_documents(index, [doc1, doc2]) + assert res["successes"] == 0 + + # the order of adding identifiers doesn't matter. a document having just the url uses only the url as identifier + doc = {"url": "http://new"} + res = upload_documents(index, [doc]) + assert res["successes"] == 1 + # second time its a duplicate + res = upload_documents(index, [doc]) + assert res["successes"] == 0 + + +def _assert_n(index, n): refresh_index(index) - assert query_documents(index).total_count == 1 + res = query_documents(index) + assert res is not None + assert res.total_count == n diff --git a/tests/test_index.py b/tests/test_index.py index 61e5f99..543d02c 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -3,7 +3,7 @@ import pytest from amcat4.config import get_settings -from amcat4.elastic import es, set_fields +from amcat4.elastic import es from amcat4.index import ( Role, create_index, @@ -25,6 +25,8 @@ set_guest_role, set_role, ) +from amcat4.fields import update_fields +from amcat4.models import Field from tests.tools import refresh @@ -35,7 +37,7 @@ def list_es_indices() -> List[str]: return list(es().indices.get(index="*").keys()) -def list_index_names(email: str = None) -> List[str]: +def list_index_names(email: str | None = None) -> List[str]: return [ix.name for ix in list_known_indices(email)] @@ -52,9 +54,9 @@ def test_create_delete_index(): assert index in list_index_names() # Cannot create or register duplicate index with pytest.raises(Exception): - create_index(index.name) + create_index(index) with pytest.raises(Exception): - register_index(index.name) + register_index(index) delete_index(index) refresh_index(get_settings().system_index) assert index not in list_es_indices() @@ -93,7 +95,7 @@ def test_list_indices(index, guest_index, admin): def test_global_roles(): user = "user@example.com" - assert get_global_role(user) is None + assert get_global_role(user) == Role.NONE set_global_role(user, Role.ADMIN) refresh_index(get_settings().system_index) assert get_global_role(user) == Role.ADMIN @@ -102,12 +104,12 @@ def test_global_roles(): assert get_global_role(user) == Role.WRITER remove_global_role(user) refresh_index(get_settings().system_index) - assert get_global_role(user) is None + assert 
get_global_role(user) == Role.NONE def test_index_roles(index): user = "user@example.com" - assert get_role(index, user) is None + assert get_role(index, user) == Role.NONE set_role(index, user, Role.METAREADER) refresh_index(get_settings().system_index) assert get_role(index, user) == Role.METAREADER @@ -116,11 +118,11 @@ def test_index_roles(index): assert get_role(index, user) == Role.ADMIN remove_role(index, user) refresh_index(get_settings().system_index) - assert get_role(index, user) is None + assert get_role(index, user) == Role.NONE def test_guest_role(index): - assert get_guest_role(index) is None + assert get_guest_role(index) == Role.NONE set_guest_role(index, Role.READER) refresh() assert get_guest_role(index) == Role.READER @@ -158,13 +160,13 @@ def test_name_description(index): assert indices[index].name == "test" -def test_summary_field(index): - with pytest.raises(Exception): - modify_index(index, summary_field="doesnotexist") - with pytest.raises(Exception): - modify_index(index, summary_field="title") - set_fields(index, {"party": "keyword"}) - modify_index(index, summary_field="party") - assert get_index(index).summary_field == "party" - modify_index(index, summary_field="date") - assert get_index(index).summary_field == "date" +# def test_summary_field(index): +# with pytest.raises(Exception): +# modify_index(index, summary_field="doesnotexist") +# with pytest.raises(Exception): +# modify_index(index, summary_field="title") +# update_fields(index, {"party": Field(type="keyword", type="keyword")}) +# modify_index(index, summary_field="party") +# assert get_index(index).summary_field == "party" +# modify_index(index, summary_field="date") +# assert get_index(index).summary_field == "date" diff --git a/tests/test_multimedia.py b/tests/test_multimedia.py new file mode 100644 index 0000000..561d097 --- /dev/null +++ b/tests/test_multimedia.py @@ -0,0 +1,31 @@ +from io import BytesIO +import os +import pytest +import requests +from amcat4 import multimedia + + +def test_upload_get_multimedia(minio, index): + assert list(multimedia.list_multimedia_objects(index)) == [] + multimedia.add_multimedia_object(index, "test", b"bytes") + assert {o.object_name for o in multimedia.list_multimedia_objects(index)} == {"test"} + + +def test_presigned_form(minio, index): + pytest.skip("mock minio does not allow presigned post, skipping for now") + assert list(multimedia.list_multimedia_objects(index)) == [] + bytes = os.urandom(32) + key = "image.png" + url, form_data = multimedia.presigned_post(index, "") + res = requests.post( + url=url, + data={"key": key, **form_data}, + files={"file": BytesIO(bytes)}, + ) + res.raise_for_status() + assert {o.object_name for o in multimedia.list_multimedia_objects(index)} == {"image.png"} + + url = multimedia.presigned_get(index, key) + res = requests.get(url) + res.raise_for_status() + assert res.content == bytes diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 7988694..acfe0db 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,46 +1,60 @@ from typing import List - +from amcat4.models import FieldSpec from amcat4.query import query_documents def test_pagination(index_many): x = query_documents(index_many, per_page=6) + assert x is not None assert x.page_count == 4 assert x.per_page == 6 assert len(x.data) == 6 assert x.page == 0 x = query_documents(index_many, per_page=6, page=3) + assert x is not None assert x.page_count == 4 assert x.per_page == 6 - assert len(x.data) == 20 - 3*6 + assert len(x.data) == 
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
index 7988694..acfe0db 100644
--- a/tests/test_pagination.py
+++ b/tests/test_pagination.py
@@ -1,46 +1,60 @@
 from typing import List
-
+from amcat4.models import FieldSpec
 from amcat4.query import query_documents


 def test_pagination(index_many):
     x = query_documents(index_many, per_page=6)
+    assert x is not None
     assert x.page_count == 4
     assert x.per_page == 6
     assert len(x.data) == 6
     assert x.page == 0
     x = query_documents(index_many, per_page=6, page=3)
+    assert x is not None
     assert x.page_count == 4
     assert x.per_page == 6
-    assert len(x.data) == 20 - 3*6
+    assert len(x.data) == 20 - 3 * 6
     assert x.page == 3


 def test_sort(index_many):
-    def q(key, per_page=5, *args, **kwargs) -> List[int]:
-        res = query_documents(index_many, per_page=per_page, sort=key, *args, **kwargs)
-        return [int(h['_id']) for h in res.data]
-    assert q('id') == [0, 1, 2, 3, 4]
-    assert q('pagenr') == [10, 9, 11, 8, 12]
-    assert q(['pagenr', 'id']) == [10, 9, 11, 8, 12]
-    assert q([{'pagenr': {"order": "desc"}}, 'id']) == [0, 1, 19, 2, 18]
+    def q(key, per_page=5) -> List[int]:
+
+        for i, k in enumerate(key):
+            if isinstance(k, str):
+                key[i] = {k: {"order": "asc"}}
+        res = query_documents(index_many, per_page=per_page, fields=[FieldSpec(name="id")], sort=key)
+        assert res is not None
+
+        print(list(res.data))
+        return [int(h["id"]) for h in res.data]
+
+    assert q(["id"]) == [0, 1, 2, 3, 4]
+    assert q(["pagenr"]) == [10, 9, 11, 8, 12]
+    assert q(["pagenr", "id"]) == [10, 9, 11, 8, 12]
+    assert q([{"pagenr": {"order": "desc"}}, "id"]) == [0, 1, 19, 2, 18]


 def test_scroll(index_many):
-    r = query_documents(index_many, queries=["odd"], scroll='5m', per_page=4)
+    r = query_documents(index_many, queries={"odd": "odd"}, scroll="5m", per_page=4, fields=[FieldSpec(name="id")])
+    assert r is not None
     assert len(r.data) == 4
     assert r.total_count, 10
     assert r.page_count == 3
     allids = list(r.data)
-    r = query_documents(index_many, scroll_id=r.scroll_id)
+    r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")])
+    assert r is not None
     assert len(r.data) == 4
     allids += r.data
-    r = query_documents(index_many, scroll_id=r.scroll_id)
+    r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")])
+    assert r is not None
     assert len(r.data) == 2
     allids += r.data
-    r = query_documents(index_many, scroll_id=r.scroll_id)
+    r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")])
     assert r is None
-    assert {int(h['_id']) for h in allids} == {0, 2, 4, 6, 8, 10, 12, 14, 16, 18}
+
+    assert {int(h["id"]) for h in allids} == {0, 2, 4, 6, 8, 10, 12, 14, 16, 18}
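Note: test_scroll relies on query_documents returning None once a scroll is exhausted. A hedged sketch of the resulting loop (same FieldSpec/scroll arguments as in the tests above; the index name and page size are illustrative):

    from typing import List

    from amcat4.models import FieldSpec
    from amcat4.query import query_documents


    def collect_all_ids(index: str) -> List[int]:
        ids: List[int] = []
        # The first call opens the scroll context; later calls pass the scroll_id back in
        res = query_documents(index, scroll="5m", per_page=100, fields=[FieldSpec(name="id")])
        while res is not None:
            ids += [int(h["id"]) for h in res.data]
            # Returns None when no documents are left in the scroll
            res = query_documents(index, scroll_id=res.scroll_id, fields=[FieldSpec(name="id")])
        return ids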
diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
new file mode 100644
index 0000000..34bc584
--- /dev/null
+++ b/tests/test_preprocessing.py
@@ -0,0 +1,170 @@
+import asyncio
+import time
+import httpx
+from pytest_httpx import HTTPXMock
+import json
+
+import pytest
+from amcat4.fields import create_fields
+from amcat4.index import get_document, reassign_preprocessing_errors, refresh_index, upload_documents, add_instruction
+from amcat4.preprocessing.models import PreprocessingInstruction
+from amcat4.preprocessing import processor
+
+from amcat4.preprocessing.processor import (
+    get_counts,
+    get_manager,
+    get_todo,
+    process_doc,
+    process_documents,
+)
+from tests.conftest import TEST_DOCUMENTS
+
+INSTRUCTION = dict(
+    field="preprocess_label",
+    task="HuggingFace Zero-Shot",
+    endpoint="https://api-inference.huggingface.co/models/facebook/bart-large-mnli",
+    arguments=[{"name": "input", "field": "text"}, {"name": "candidate_labels", "value": ["politics", "sports"]}],
+    outputs=[{"name": "label", "field": "class_label"}],
+)
+
+
+def test_build_request(index):
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    doc = dict(text="Sample text")
+    req = i.build_request(index, doc)
+    assert req.url == INSTRUCTION["endpoint"]
+    assert json.loads(req.content) == dict(inputs=doc["text"], parameters=dict(candidate_labels=["politics", "sports"]))
+
+
+def test_parse_result():
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    output = {"labels": ["politics", "sports"], "scores": [0.9, 0.1]}
+    update = dict(i.parse_output(output))
+    assert update == dict(class_label="politics")
+
+
+@pytest.mark.asyncio
+async def test_preprocess(index_docs, httpx_mock: HTTPXMock):
+    """Test logic of process_doc and get_todo calls"""
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]})
+
+    # Create a preprocess field. There should now be |docs| documents todo
+    create_fields(index_docs, {i.field: "preprocess"})
+    todos = list(get_todo(index_docs, i))
+    assert all(set(todo.keys()) == {"_id", "text"} for todo in todos)
+    assert {doc["_id"] for doc in todos} == {str(doc["_id"]) for doc in TEST_DOCUMENTS}
+
+    # Process a single document. Check that it's done, and that the todo list is now one shorter
+    todo = sorted(todos, key=lambda todo: todo["_id"])[0]
+    await process_doc(index_docs, i, todo)
+    doc = get_document(index_docs, todo["_id"])
+    assert doc[i.field] == {"status": "done"}
+    assert doc["class_label"] == "politics"
+    refresh_index(index_docs)
+    todos = list(get_todo(index_docs, i))
+    assert {doc["_id"] for doc in todos} == {str(doc["_id"]) for doc in TEST_DOCUMENTS} - {todo["_id"]}
+
+    # run a single preprocessing loop, check that done is False and that the todo list shrank by the batch size
+    done = await process_documents(index_docs, i, size=2)
+    assert done is False
+    refresh_index(index_docs)
+    todos = list(get_todo(index_docs, i))
+    assert len(todos) == len(TEST_DOCUMENTS) - (2 + 1)
+
+    # run preprocessing until it returns done = True
+    while not done:
+        done = await process_documents(index_docs, i, size=2)
+
+    # Todo should be empty, and there should be one call per document!
+    refresh_index(index_docs)
+    todos = list(get_todo(index_docs, i))
+    assert len(todos) == 0
+    assert len(httpx_mock.get_requests()) == len(TEST_DOCUMENTS)
+
+
+@pytest.mark.asyncio
+async def test_preprocess_loop(index_docs, httpx_mock: HTTPXMock):
+    """Test that adding an instruction automatically processes all docs in an index"""
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]})
+    add_instruction(index_docs, i)
+    await get_manager().running_tasks[index_docs, i.field]
+    assert len(httpx_mock.get_requests()) == len(TEST_DOCUMENTS)
+    assert all(get_document(index_docs, doc["_id"])["class_label"] == "politics" for doc in TEST_DOCUMENTS)
+
+
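Note: the INSTRUCTION above wires the document's "text" field into the zero-shot input and writes the top label back to "class_label". For reference, the request/response shapes exchanged with the (mocked) endpoint, exactly as asserted in test_build_request and test_parse_result:

    # Shapes only; the values mirror the asserts above, nothing new is assumed.
    request_body = {
        "inputs": "Sample text",  # value of the document field named in INSTRUCTION["arguments"]
        "parameters": {"candidate_labels": ["politics", "sports"]},
    }
    response_body = {"labels": ["politics", "sports"], "scores": [0.9, 0.1]}
    # parse_output takes the top-scoring label and maps it onto the configured output field:
    update = {"class_label": response_body["labels"][0]}  # -> "politics"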
+@pytest.mark.asyncio
+async def test_preprocess_logic(index, httpx_mock: HTTPXMock):
+    """Test that main processing loop works correctly"""
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+
+    async def mock_slow_response(_request) -> httpx.Response:
+        await asyncio.sleep(0.5)
+        return httpx.Response(json={"labels": ["politics"], "scores": [1]}, status_code=200)
+
+    httpx_mock.add_callback(mock_slow_response, url=i.endpoint)
+
+    # Add the instruction. Since there are no documents, it should return instantly-ish
+    add_instruction(index, i)
+    await asyncio.sleep(0.1)
+    assert get_manager().get_status(index, i.field) == "Done"
+
+    # Add a document. The task should be re-activated and take half a second to complete
+    upload_documents(index, [{"text": "text"}], fields={"text": "text"})
+    await asyncio.sleep(0.1)
+    assert get_manager().get_status(index, i.field) == "Active"
+    await asyncio.sleep(0.5)
+    assert get_manager().get_status(index, i.field) == "Done"
+
+
+@pytest.mark.asyncio
+async def test_preprocess_ratelimit(index_docs, httpx_mock: HTTPXMock):
+    """Test that processing is paused on hitting a rate limit, and restarts automatically"""
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    httpx_mock.add_response(url=i.endpoint, status_code=503)
+
+    # Set a low pause time for the test
+    processor.PAUSE_ON_RATE_LIMIT_SECONDS = 0.5
+
+    # Start the async preprocessing loop. On receiving a 503 it should pause and retry after the pause interval
+    add_instruction(index_docs, i)
+    await asyncio.sleep(0.1)
+    assert get_manager().get_status(index_docs, i.field) == "Paused"
+
+    # Now mock a success response and wait for .5 seconds
+    httpx_mock.reset(assert_all_responses_were_requested=True)
+    httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]})
+    await asyncio.sleep(0.5)
+    assert get_manager().get_status(index_docs, i.field) == "Done"
+
+
+@pytest.mark.asyncio
+async def test_preprocess_error(index_docs, httpx_mock: HTTPXMock):
+    """Test that errors are reported correctly"""
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+
+    def some_errors(request):
+        input = json.loads(request.content)["inputs"]
+        if "text" in input:  # should be true for 2 documents
+            return httpx.Response(json={"error": "I'm a teapot!"}, status_code=418)
+        else:
+            return httpx.Response(json={"labels": ["politics"], "scores": [1]}, status_code=200)
+
+    httpx_mock.add_callback(some_errors, url=i.endpoint)
+    add_instruction(index_docs, i)
+    await get_manager().running_tasks[index_docs, i.field]
+    for doc in TEST_DOCUMENTS:
+        result = get_document(index_docs, doc["_id"])
+        assert result[i.field]["status"] == ("error" if "text" in doc["text"] else "done")
+    assert get_counts(index_docs, i.field) == dict(total=4, done=2, error=2)
+
+    httpx_mock.reset(True)
+    httpx_mock.add_response(url=i.endpoint, json={"labels": ["sports"], "scores": [1]})
+    reassign_preprocessing_errors(index_docs, i.field)
+    await get_manager().running_tasks[index_docs, i.field]
+    for doc in TEST_DOCUMENTS:
+        result = get_document(index_docs, doc["_id"])
+        assert result[i.field]["status"] == "done"
+        assert result["class_label"] == ("sports" if "text" in doc["text"] else "politics")
+    assert len(httpx_mock.get_requests()) == 2
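Note: the tests above drive the background processor through its states (Active, Paused on a 503, Done) and check progress with get_counts. A hedged sketch of that flow, using only the calls exercised here (the example counts are the ones asserted above):

    import asyncio

    from amcat4.index import add_instruction, reassign_preprocessing_errors
    from amcat4.preprocessing.processor import get_counts, get_manager


    async def run_and_retry(index, instruction):
        add_instruction(index, instruction)  # starts the background preprocessing task
        await get_manager().running_tasks[index, instruction.field]  # wait until the first pass finishes
        counts = get_counts(index, instruction.field)  # e.g. {"total": 4, "done": 2, "error": 2}
        if counts.get("error"):
            reassign_preprocessing_errors(index, instruction.field)  # re-queue failed documents
            await get_manager().running_tasks[index, instruction.field]
        return get_counts(index, instruction.field)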
diff --git a/tests/test_query.py b/tests/test_query.py
index 3491c49..ac8bfa4 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -1,68 +1,107 @@
 import functools
-import re
 from typing import Set, Optional

 from amcat4 import query
+from amcat4.models import FieldSpec, FilterSpec, FilterValue, SnippetParams, UpdateField
+from amcat4.api.query import _standardize_queries, _standardize_filters
 from tests.conftest import upload


-def query_ids(index: str, q: Optional[str] = None, **kwargs) -> Set[int]:
+def query_ids(
+    index: str,
+    q: Optional[str | list[str]] = None,
+    filters: dict[str, FilterValue | list[FilterValue] | FilterSpec] | None = None,
+    **kwargs,
+) -> Set[int]:
     if q is not None:
-        kwargs['queries'] = [q]
+        kwargs["queries"] = _standardize_queries(q)
+    if filters is not None:
+        kwargs["filters"] = _standardize_filters(filters)
+
     res = query.query_documents(index, **kwargs)
-    return {int(h['_id']) for h in res.data}
+    if res is None:
+        return set()
+    return {int(h["_id"]) for h in res.data}


 def test_query(index_docs):
     q = functools.partial(query_ids, index_docs)
+
     assert q("test") == {1, 2}
     assert q("test*") == {1, 2, 3}
     assert q('"a text"') == {0}
-    assert q(queries=["this", "toto"]) == {0, 2, 3}
+    assert q(["this", "toto"]) == {0, 2, 3}

-    assert q(filters={"title": {"value": "title"}}) == {0, 1}
-    assert q("this", filters={"title": {"value": "title"}}) == {0}
+    assert q(filters={"title": ["title"]}) == {0, 1}
+    assert q("this", filters={"title": ["title"]}) == {0}
     assert q("this") == {0, 2}


+def test_snippet(index_docs):
+    docs = query.query_documents(index_docs, fields=[FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=5))])
+    assert docs is not None
+    assert docs.data[0]["text"] == "this is"
+
+    docs = query.query_documents(
+        index_docs, queries={"1": "a"}, fields=[FieldSpec(name="text", snippet=SnippetParams(max_matches=1, match_chars=1))]
+    )
+    assert docs is not None
+    assert docs.data[0]["text"] == "a"
+
+
 def test_range_query(index_docs):
     q = functools.partial(query_ids, index_docs)
-    assert q(filters={"date": {"gt": "2018-02-01"}}) == {2}
-    assert q(filters={"date": {"gte": "2018-02-01"}}) == {1, 2}
-    assert q(filters={"date": {"gte": "2018-02-01", "lt": "2020-01-01"}}) == {1}
-    assert q("title", filters={"date": {"gt": "2018-01-01"}}) == {1}
+    assert q(filters={"date": FilterSpec(gt="2018-02-01")}) == {2}
+    assert q(filters={"date": FilterSpec(gte="2018-02-01")}) == {1, 2}
+    assert q(filters={"date": FilterSpec(gte="2018-02-01", lt="2020-01-01")}) == {1}
+    assert q("title", filters={"date": FilterSpec(gt="2018-01-01")}) == {1}


 def test_fields(index_docs):
-    res = query.query_documents(index_docs, queries=["test"], fields=["cat", "title"])
+    res = query.query_documents(index_docs, queries={"1": "test"}, fields=[FieldSpec(name="cat"), FieldSpec(name="title")])
+    assert res is not None
     assert set(res.data[0].keys()) == {"cat", "title", "_id"}


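Note: the building blocks used throughout this file (named queries, FieldSpec with SnippetParams, FilterSpec ranges) combine into a single query_documents call. A hedged sketch, reusing only the helpers imported above; the index name is illustrative:

    from amcat4 import query
    from amcat4.api.query import _standardize_filters, _standardize_queries
    from amcat4.models import FieldSpec, FilterSpec, SnippetParams

    res = query.query_documents(
        "my_index",  # illustrative index name
        queries=_standardize_queries(["test*"]),
        filters=_standardize_filters({"date": FilterSpec(gte="2018-02-01", lt="2020-01-01")}),
        fields=[
            FieldSpec(name="title"),
            FieldSpec(name="text", snippet=SnippetParams(max_matches=3, match_chars=50)),
        ],
        highlight=True,
    )
    if res is not None:  # query_documents may return None (e.g. an exhausted scroll)
        for doc in res.data:
            print(doc["_id"], doc.get("title"), doc.get("text"))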
 def test_highlight(index):
     words = "The error of regarding functional notions is not quite equivalent to"
     text = f"{words} a test document. {words} other text documents. {words} you!"
-    upload(index, [dict(title="Een test titel", text=text)])
-    res = query.query_documents(index, queries=["te*"], highlight=True)
+    upload(index, [dict(title="Een test titel", text=text)], fields={"title": "text", "text": "text"})
+    res = query.query_documents(
+        index, fields=[FieldSpec(name="title"), FieldSpec(name="text")], queries={"1": "te*"}, highlight=True
+    )
+    assert res is not None
     doc = res.data[0]
-    assert doc['title'] == "Een test titel"
-    assert doc['text'] == f"{words} a test document. {words} other text documents. {words} you!"
-
-    doc = query.query_documents(index, queries=["te*"], highlight={"number_of_fragments": 1}).data[0]
-    assert doc['title'] == "Een test titel"
-    assert " a test" in doc['text']
-    assert "..." not in doc['text']
-
-    doc = query.query_documents(index, queries=["te*"], highlight={"number_of_fragments": 2}).data[0]
-    assert re.search(r" a test[^<]*...[^<]*other text documents", doc['text'])
+    assert doc["title"] == "Een test titel"
+    assert doc["text"] == f"{words} a test document. {words} other text documents. {words} you!"
+
+    res = query.query_documents(
+        index,
+        queries={"1": "te*"},
+        fields=[
+            FieldSpec(name="title", snippet=SnippetParams(max_matches=3, match_chars=50)),
+            FieldSpec(name="text", snippet=SnippetParams(max_matches=3, match_chars=50)),
+        ],
+        highlight=True,
+    )
+    assert res is not None
+    doc = res.data[0]
+    assert doc["title"] == "Een test titel"
+    assert " a test" in doc["text"]
+    assert " ... " in doc["text"]


 def test_query_multiple_index(index_docs, index):
-    upload(index, [{"text": "also a text", "i": -1}])
-    assert len(query.query_documents([index_docs, index]).data) == 5
-
-
-def test_query_filter_mapping(index_docs):
-    q = functools.partial(query_ids, index_docs)
-    assert q(filters={"date": {"monthnr": "2"}}) == {1}
-    assert q(filters={"date": {"dayofweek": "Monday"}}) == {0, 3}
+    upload(index, [{"text": "also a text", "i": -1}], fields={"i": "integer", "text": "text"})
+    docs = query.query_documents([index_docs, index])
+    assert docs is not None
+    assert len(docs.data) == 5
+
+
+# TODO: Do we want to support this? What are the options?
+# If so, need to add it to FilterSpec
+# def test_query_filter_mapping(index_docs):
+#     q = functools.partial(query_ids, index_docs)
+#     assert q(filters={"date": {"monthnr": "2"}}) == {1}
+#     assert q(filters={"date": {"dayofweek": "Monday"}}) == {0, 3}
diff --git a/tests/tools.py b/tests/tools.py
index 1a7e27f..c522c91 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -3,16 +3,16 @@ from datetime import datetime, date

 from typing import Set, Iterable, Optional

-import requests
 from authlib.jose import jwt
 from fastapi.testclient import TestClient
+from httpx import AsyncClient

 from amcat4.config import AuthOptions, get_settings
 from amcat4.index import refresh_index
 from tests.middlecat_keypair import PRIVATE_KEY


-def create_token(**payload) -> bytes:
+def create_token(**payload) -> str:
     header = {"alg": "RS256"}
     token = jwt.encode(header, payload, PRIVATE_KEY)
     return token.decode("utf-8")
@@ -31,24 +31,30 @@ def build_headers(user=None, headers=None):
     return headers


-def get_json(client: TestClient, url, expected=200, headers=None, user=None, **kargs):
+def get_json(client: TestClient, url: str, expected=200, headers=None, user=None, **kargs) -> dict:
     """Get the given URL. If expected is 2xx, return the result as parsed json"""
     response = client.get(url, headers=build_headers(user, headers), **kargs)
     content = response.json() if response.content else None
-    assert (
-        response.status_code == expected
-    ), f"GET {url} returned {response.status_code}, expected {expected}, {content}"
-    if expected // 100 == 2:
-        return content
+    assert response.status_code == expected, f"GET {url} returned {response.status_code}, expected {expected}, {content}"
+    return {} if content is None else content
+
+
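Note: these helpers wrap the FastAPI TestClient so individual tests only state the expected status code. A hedged usage sketch (the endpoint path and user address are illustrative, not taken from this diff):

    from fastapi.testclient import TestClient

    from amcat4.api import app
    from tests.tools import get_json


    def test_example_endpoint():
        client = TestClient(app)
        # get_json asserts the status code and returns the parsed JSON body ({} for empty responses)
        body = get_json(client, "/some/endpoint", expected=200, user="user@example.com")
        assert isinstance(body, dict)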
+async def aget_json(client: AsyncClient, url: str, expected=200, headers=None, user=None, **kargs) -> dict:
+    """Get the given URL. If expected is 2xx, return the result as parsed json"""
+    response = await client.get(url, headers=build_headers(user, headers), **kargs)
+    content = response.json() if response.content else None
+    assert response.status_code == expected, f"GET {url} returned {response.status_code}, expected {expected}, {content}"
+    return {} if content is None else content


 def post_json(client: TestClient, url, expected=201, headers=None, user=None, **kargs):
     response = client.post(url, headers=build_headers(user, headers), **kargs)
     assert response.status_code == expected, (
-        f"POST {url} returned {response.status_code}, expected {expected}\n"
-        f"{response.json()}"
+        f"POST {url} returned {response.status_code}, expected {expected}\n" f"{response.json()}"
     )
-    if expected != 204:
+    if expected == 204:
+        return {}
+    else:
         return response.json()
@@ -64,7 +70,7 @@ def dictset(dicts: Iterable[dict]) -> Set[str]:
     return {json.dumps(dict(sorted(d.items())), cls=DateTimeEncoder) for d in dicts}


-def check(response: requests.Response, expected: int, msg: Optional[str] = None):
+def check(response, expected: int, msg: Optional[str] = None):
     assert response.status_code == expected, (
         f"{msg or ''}{': ' if msg else ''}Unexpected status: received {response.status_code} != expected {expected};"
         f" reply: {response.json()}"