From e21ee15b36eada4a2d4a1ebfbd9737289d49687c Mon Sep 17 00:00:00 2001
From: Kasper Welbers
Date: Mon, 25 Dec 2023 15:12:32 +0100
Subject: [PATCH 01/80] new server updates for new client. Currently breaking
 due to changes in metareader field access

---
 amcat4/api/index.py         |  26 ++++++++-
 amcat4/api/query.py         |  76 +++++++++++++++++--------
 amcat4/elastic.py           |  24 ++++++--
 amcat4/query.py             | 108 +++++++++++++++---------------------
 tests/__init__.py           |   0
 tests/test_api_documents.py |  28 +++++-----
 tests/test_elastic.py       |   2 +-
 7 files changed, 156 insertions(+), 108 deletions(-)
 create mode 100644 tests/__init__.py

diff --git a/amcat4/api/index.py b/amcat4/api/index.py
index fab1791..3fb5584 100644
--- a/amcat4/api/index.py
+++ b/amcat4/api/index.py
@@ -291,10 +291,30 @@ def set_fields(


 @app_index.get("/{ix}/fields/{field}/values")
-def get_values(ix: str, field: str, _=Depends(authenticated_user)):
-    """Get the fields (columns) used in this index."""
-    return elastic.get_values(ix, field, size=100)
+def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user)):
+    """
+    Get unique values for a specific field. Should mainly/only be used for tag fields.
+    Main purpose is to provide a list of values for a dropdown menu.
+
+    TODO: at the moment this 'only' returns the top 2000 values, and throws an
+    error if there are more than 2000 unique values. We can increase this limit, but
+    there should be some limit. Querying could be an option, but it is not clear whether
+    that is efficient, since elastic has to aggregate all values first.
+    """
+    check_role(user, Role.READER, ix)
+    values = elastic.get_field_values(ix, field, size=2001)
+    if len(values) > 2000:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Field {field} has more than 2000 unique values",
+        )
+    return values

+@app_index.get("/{ix}/fields/{field}/stats")
+def get_field_stats(ix: str, field: str, user: str = Depends(authenticated_user)):
+    """Get statistics for a specific field. Only works for numeric (incl. date) fields."""
+    check_role(user, Role.READER, ix)
+    return elastic.get_field_stats(ix, field)

 @app_index.get("/{ix}/users")
 def list_index_users(ix: str, user: str = Depends(authenticated_user)):
diff --git a/amcat4/api/query.py b/amcat4/api/query.py
index a923b06..7317c7e 100644
--- a/amcat4/api/query.py
+++ b/amcat4/api/query.py
@@ -6,7 +6,7 @@
 from fastapi.params import Body
 from pydantic.main import BaseModel

-from amcat4 import query, aggregate
+from amcat4 import elastic, query, aggregate
 from amcat4.aggregate import Axis, Aggregation
 from amcat4.api.auth import authenticated_user, check_role
 from amcat4.index import Role
@@ -33,12 +33,24 @@ class QueryResult(BaseModel):


 def _check_query_role(
-    indices: List[str], user: str, fields: List[str], highlight: bool
+    indices: List[str], user: str, fields: List[str], snippets: Optional[List[str]] = None
 ):
-    if (not fields) or ("text" in fields) or (highlight):
-        role = Role.READER
-    else:
+    # TODO: index setting for which fields METAREADERS can see.
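+    # (The two lists below are a hardcoded stand-in for that setting.)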
+ # Each field should say both: metareader_visible and metareader_snippet + metareader_visible = ["date", "title", "url"] + metareader_snippet = ["text"] + + all_values_in = lambda a, b: all([x in b for x in a]) + meta_visible = (not fields) or all_values_in(fields, metareader_visible) + meta_visible_snippet = (not snippets) or all_values_in(snippets, metareader_snippet) + + + if (meta_visible and meta_visible_snippet): role = Role.METAREADER + else: + role = Role.READER + + print(role) for ix in indices: check_role(user, role, ix) @@ -63,6 +75,15 @@ def get_documents( description="Comma separated list of fields to return", pattern=r"\w+(,\w+)*", ), + snippets: str = Query( + None, + description="Comma separated list of fields to return as snippets", + pattern=r"\w+(,\w+)*", + ), + highlight: bool = Query( + False, + description="If true, highlight fields" + ), per_page: int = Query(None, description="Number of results per page"), page: int = Query(None, description="Page to fetch"), scroll: str = Query( @@ -71,12 +92,6 @@ def get_documents( examples="3m", ), scroll_id: str = Query(None, description="Get the next batch from this scroll id"), - highlight: bool = Query(False, description="add highlight tags "), - annotations: bool = Query( - False, - description="if true, also return _annotations " - "with query matches as annotations", - ), user: str = Depends(authenticated_user), ): """ @@ -94,13 +109,21 @@ def get_documents( fields = fields and fields.split(",") if not fields: fields = ["date", "title", "url"] - _check_query_role(indices, user, fields, highlight) + + snippets = snippets and snippets.split(",") + if snippets: + for field in fields: + if field in snippets: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Field {field} cannot be in both fields and snippets") + + _check_query_role(indices, user, fields, snippets) + args = {} sort = sort and [ {x.replace(":desc", ""): "desc"} if x.endswith(":desc") else x for x in sort.split(",") ] - known_args = ["page", "per_page", "scroll", "scroll_id", "highlight", "annotations"] + known_args = ["page", "per_page", "scroll", "scroll_id", "highlight"] for name in known_args: val = locals()[name] if val: @@ -186,6 +209,9 @@ def query_documents_post( fields: Optional[List[str]] = Body( None, description="List of fields to retrieve for each document" ), + snippets: Optional[List[str]] = Body( + None, description="Fields to retrieve as snippets" + ), filters: Optional[ Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] ] = Body( @@ -221,13 +247,9 @@ def query_documents_post( scroll_id: Optional[str] = Body( None, description="Scroll id from previous response to continue scrolling" ), - annotations: Optional[bool] = Body( - None, description="Return _annotations with query matches as annotations" - ), - highlight: Optional[Union[bool, Dict]] = Body( - None, - description="Highlight document. 
'true' highlights whole document, see elastic docs for dict format" - "https://www.elastic.co/guide/en/elasticsearch/reference/7.17/highlighting.html", + highlight: Optional[bool] = Body( + False, + description="If true, highlight fields" ), user=Depends(authenticated_user), ): @@ -240,12 +262,20 @@ def query_documents_post( # Standardize fields, queries and filters to their most versatile format indices = index.split(",") if fields: - # to array format: fields: [field1, field2] if isinstance(fields, str): fields = [fields] else: fields = ["date", "title", "url"] - _check_query_role(indices, user, fields, highlight is not None) + + if snippets: + if isinstance(snippets, str): + snippets = [snippets] + for field in fields: + if field in snippets: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Field {field} cannot be in both fields and snippets") + + #field_meta = elastic.get_fields(index) + _check_query_role(indices, user, fields, snippets) queries = _process_queries(queries) filters = dict(_process_filters(filters)) @@ -254,12 +284,12 @@ def query_documents_post( queries=queries, filters=filters, fields=fields, + snippets=snippets, sort=sort, per_page=per_page, page=page, scroll_id=scroll_id, scroll=scroll, - annotations=annotations, highlight=highlight, ) if r is None: diff --git a/amcat4/elastic.py b/amcat4/elastic.py index a9e33f8..90fc4ff 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -307,16 +307,32 @@ def get_fields(index: Union[str, Sequence[str]]): return result -def get_values(index: str, field: str, size: int = 100) -> List[str]: +def get_field_values(index: str, field: str, size: int) -> List[str]: """ Get the values for a given field (e.g. to populate list of filter values on keyword field) + Results are sorted descending by document frequency + see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html#search-aggregations-bucket-terms-aggregation-order + :param index: The index :param field: The field name :return: A list of values """ - aggs = {"values": {"terms": {"field": field}}} - r = es().search(index=index, size=size, aggs=aggs) - return [x["key"] for x in r["aggregations"]["values"]["buckets"]] + aggs = {"unique_values": { + "terms": {"field": field, "size": size} + }} + r = es().search(index=index, size=0, aggs=aggs) + return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]] + +def get_field_stats(index: str, field: str) -> List[str]: + """ + Get field statistics, such as min, max, avg, etc. 
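+    For a numeric field, the result of the 'stats' aggregation looks like
+    (illustrative values): {"count": 100, "min": 1.0, "max": 9.0, "avg": 4.5, "sum": 450.0}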
+ :param index: The index + :param field: The field name + :return: A list of values + """ + aggs = {"facets": {"stats": {"field": field}}} + r = es().search(index=index, size=0, aggs=aggs) + return r["aggregations"]["facets"] def update_by_query(index: str, script: str, query: dict, params: dict = None): diff --git a/amcat4/query.py b/amcat4/query.py index 05ce3cf..b60ab16 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -2,6 +2,7 @@ All things query """ from math import ceil +import json from re import finditer from re import sub from typing import Mapping, Iterable, Optional, Union, Sequence, Any, Dict, List, Tuple, Literal @@ -10,7 +11,7 @@ from .elastic import es, update_tag_by_query -def build_body(queries: Iterable[str] = None, filters: Mapping = None, highlight: Union[bool, dict] = False, +def build_body(queries: Iterable[str] = None, filters: Mapping = None, highlight: dict = None, ids: Iterable[str] = None): def parse_filter(field, filter) -> Tuple[Mapping, Mapping]: filter = filter.copy() @@ -66,13 +67,10 @@ def parse_queries(qs: Sequence[str]) -> dict: body: Dict[str, Any] = {"query": {"bool": {"filter": fs}}} if runtime_mappings: body['runtime_mappings'] = runtime_mappings - if highlight is True: - highlight = {"number_of_fragments": 0} - elif highlight: - highlight = {**{"number_of_fragments": 0, "fragment_size": 40, "type": "plain"}, **highlight} - if highlight: - body['highlight'] = {"type": 'unified', "require_field_match": True, - "fields": {"*": highlight}} + + if highlight is not None: + body["highlight"] = highlight + return body @@ -110,9 +108,10 @@ def _normalize_queries(queries: Optional[Union[Dict[str, str], Iterable[str]]]) def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str, str], Iterable[str]] = None, *, page: int = 0, per_page: int = 10, - scroll=None, scroll_id: str = None, fields: Iterable[str] = None, + scroll=None, scroll_id: str = None, + fields: Iterable[str] = None, snippets: Iterable[str] = None, filters: Mapping[str, Mapping] = None, - highlight: Union[bool, dict] = False, annotations=False, + highlight: Literal["none", "text", "snippets"] = "none", sort: List[Union[str, Mapping]] = None, **kwargs) -> Optional[QueryResult]: """ @@ -130,15 +129,13 @@ def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str specify the time the context should be kept alive, or True to get the default of 2m. :param scroll_id: if not None, should be a previously returned context_id to retrieve a new page of results :param fields: if not None, specify a list of fields to retrieve for each hit + :param snippets: if not None, specify a list of fields to retrieve snippets for :param filters: if not None, a dict of filters with either value, values, or gte/gt/lte/lt ranges: {field: {'values': [value1,value2], 'value': value, 'gte/gt/lte/lt': value, ...}} - :param highlight: if True, add highlight tags () to all results. - If a dict, it can be used to control highlighting, e.g. to get multiple snippets - (https://www.elastic.co/guide/en/elasticsearch/reference/7.17/highlighting.html) - :param annotations: if True, get query matches as annotations. + :param highlight: if True, add tags to query matches in fields :param sort: Sort order of results, can be either a single field or a list of fields. In the list, each field is a string or a dict with options, e.g. 
["id", {"date": {"order": "desc"}}] (https://www.elastic.co/guide/en/elasticsearch/reference/current/sort-search-results.html) @@ -156,8 +153,9 @@ def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str if not result['hits']['hits']: return None else: - body = build_body(queries.values(), filters, highlight) - + h = query_highlight(fields, highlight, snippets) + body = build_body(queries.values(), filters, h) + if fields: fields = fields if isinstance(fields, list) else list(fields) kwargs['_source'] = fields @@ -168,8 +166,7 @@ def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str data = [] for hit in result['hits']['hits']: hitdict = dict(_id=hit['_id'], **hit['_source']) - if annotations: - hitdict['_annotations'] = list(query_annotations(index, hit['_id'], queries)) + hitdict = overwrite_highlight_results(hit, hitdict) if 'highlight' in hit: for key in hit['highlight'].keys(): if hit['highlight'][key]: @@ -183,55 +180,40 @@ def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str return QueryResult(data, n=result['hits']['total']['value'], per_page=per_page, page=page) -def query_annotations(index: str, id: str, queries: Mapping[str, str]) -> Iterable[Dict]: +def query_highlight(fields: Iterable[str], highlight: bool, snippets: Iterable[str]): """ - get query matches in annotation format. Currently does so per hit per query. - Per hit could be optimized, but per query seems necessary: - https://stackoverflow.com/questions/44621694/elasticsearch-highlight-with-multiple-queries-not-work-as-expected + The elastic "highlight" parameters works for both highlighting text fields and adding snippets. + This function will return the highlight parameter to be added to the query body. """ - - if not queries: - return - for label, query in queries.items(): - body = build_body([query], {'_id': {'value': id}}, True) - - result = es().search(index=index, body=body) - hit = result['hits']['hits'] - if len(hit) == 0: - continue - for field, highlights in hit[0]['highlight'].items(): - text = hit[0]["_source"][field] - if isinstance(text, list): - continue - for span in extract_highlight_span(text, highlights[0]): - span['variable'] = 'query' - span['value'] = label - span['field'] = field - yield span - - -def extract_highlight_span(text: str, highlight: str): + if (fields is None or highlight is False) and (snippets is None): + return None + + highlight = {"require_field_match": False, "fields": {}} + + if fields is not None: + for field in fields: + highlight["fields"][field] = {"number_of_fragments": 0} + + if snippets is not None: + # TODO: get index meta data to see which snippets are allowed and what + # the nr and size should be + for field in snippets: + highlight["fields"][field] = {"no_match_size": 200, "number_of_fragments": 3, "fragment_size": 40} + + return highlight + +def overwrite_highlight_results(hit: dict, hitdict: dict): """ - It doesn't seem possible to get the offsets of highlights: - https://github.com/elastic/elasticsearch/issues/5736 - - We can get the offsets from the tags, but not yet sure how stable this is. - text is the text in the _source field. highlight should be elastics highlight if nr of fragments = 0 (i.e. full text) + highlights are a separate field in the hits. If highlight is True, we want to overwrite + the original field with the highlighted version. If there are snippets, we want to add them """ - # elastic highlighting internally trims... 
- # this hack gets the offset of the trimmed text, but it's not an ideal solution - trimmed_offset = len(text) - len(text.lstrip()) - - side_by_side = ' ' - highlight = sub(side_by_side, ' ', highlight) - regex = '.+?' - tagsize = 9 # - for i, m in enumerate(finditer(regex, highlight)): - offset = trimmed_offset + m.start(0) - tagsize*i - length = len(m.group(0)) - tagsize - yield dict(offset=offset, length=length) - - + if not hit.get('highlight'): + return hitdict + for key in hit['highlight'].keys(): + if hit['highlight'][key]: + hitdict[key] = " ... ".join(hit['highlight'][key]) + return hitdict + def update_tag_query(index: Union[str, Sequence[str]], action: Literal["add", "remove"], field: str, tag: str, queries: Union[Mapping[str, str], Iterable[str]] = None, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_api_documents.py b/tests/test_api_documents.py index 39f4bde..e401778 100644 --- a/tests/test_api_documents.py +++ b/tests/test_api_documents.py @@ -69,29 +69,29 @@ def test_metareader(client, index, index_docs, user, reader): def get_join(x): return ",".join(x) if isinstance(x, list) else x - # Metareader should not be able to query text (including highlight) - for ix, u, fields, highlight, outcome in [ - (index, user, ["text"], False, 401), - (index_docs, user, ["text"], False, 200), - ([index_docs, index], user, ["text"], False, 401), - (index, user, ["text", "title"], False, 401), - (index, user, ["title"], False, 200), - (index, reader, ["text"], False, 200), - ([index_docs, index], reader, ["text"], False, 200), - (index, user, ["title"], True, 401), - (index, reader, ["title"], True, 200), + # Metareader should not be able to query text as a field. Only as a snippet + for ix, u, fields, snippets, outcome in [ + (index, user, ["text"], None, 401), + (index_docs, user, ["text"], None, 200), + ([index_docs, index], user, ["text"], None, 401), + (index, user, ["text", "title"], None, 401), + (index, user, ["title"], None, 200), + (index, reader, ["text"], None, 200), + ([index_docs, index], reader, ["text"], None, 200), + (index, reader, ["title"], ["text"], 200) ]: + snippets_param = get_join(snippets) check( client.get( - f"/index/{get_join(ix)}/documents?fields={get_join(fields)}{'&highlight=true' if highlight else ''}", + f"/index/{get_join(ix)}/documents?fields={get_join(fields)}{'&snippets=' + snippets_param if snippets else ''}", headers=build_headers(u), ), outcome, msg=f"Index: {ix}, user: {u}, fields: {fields}", ) body = {"fields": fields} - if highlight: - body["highlight"] = True + if snippets: + body["snippets"] = snippets check( client.post( f"/index/{get_join(ix)}/query", diff --git a/tests/test_elastic.py b/tests/test_elastic.py index a4ee366..d6d271e 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -43,7 +43,7 @@ def test_fields(index): def test_values(index): """Can we get values for a specific field""" upload(index, [dict(bla=x) for x in ["odd", "even", "even"] * 10], fields={"bla": "keyword"}) - assert set(elastic.get_values(index, "bla")) == {"odd", "even"} + assert set(elastic.get_field_values(index, "bla", 10)) == {"odd", "even"} def test_update(index_docs): From e2842c9952ab90501a32a165eadc199368125ad8 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Mon, 25 Dec 2023 16:21:19 +0100 Subject: [PATCH 02/80] merging all field permissions because we thought it would be cool to allow querying multiple indices --- amcat4/api/query.py | 47 
++++++++++++++++++++++
 amcat4/elastic.py           | 37 +++++++++++++++++++++++++++--
 tests/test_api_documents.py |  4 ++--
 3 files changed, 63 insertions(+), 25 deletions(-)

diff --git a/amcat4/api/query.py b/amcat4/api/query.py
index 7317c7e..de8b0c4 100644
--- a/amcat4/api/query.py
+++ b/amcat4/api/query.py
@@ -33,26 +33,30 @@ class QueryResult(BaseModel):


 def _check_query_role(
-    indices: List[str], user: str, fields: List[str], snippets: Optional[List[str]] = None
+    indices: List[str], index_fields: dict, user: str, fields: List[str], snippets: Optional[List[str]] = None
 ):
-    # TODO: index setting for which fields METAREADERS can see.
-    # (The two lists below are a hardcoded stand-in for that setting.)
-    # Each field should say both: metareader_visible and metareader_snippet
-    metareader_visible = ["date", "title", "url"]
-    metareader_snippet = ["text"]
-
-    all_values_in = lambda a, b: all([x in b for x in a])
-    meta_visible = (not fields) or all_values_in(fields, metareader_visible)
-    meta_visible_snippet = (not snippets) or all_values_in(snippets, metareader_snippet)
-
-
-    if (meta_visible and meta_visible_snippet):
-        role = Role.METAREADER
+    """
+    Check whether the user needs to have metareader or reader role.
+    The index_fields (from elastic.get_fields) contains meta information about
+    field access in the index. For multiple indices, the most restrictive setting is used.
+    """
+    metareader_visible = index_fields.get("meta", {}).get("metareader_visible", [])
+    metareader_snippet = index_fields.get("meta", {}).get("metareader_snippet", [])
+
+    def visible_to_metareader(fields, metareader_fields):
+        if (not fields):
+            return True
+        return all([x in metareader_fields for x in fields])
+
+    meta_visible = visible_to_metareader(fields, metareader_visible)
+    meta_visible_snippet = visible_to_metareader(snippets, metareader_snippet)
+    if meta_visible and meta_visible_snippet:
+        required_role = Role.METAREADER
     else:
-        role = Role.READER
-
-    print(role)
+        required_role = Role.READER
+
     for ix in indices:
-        check_role(user, role, ix)
+        check_role(user, required_role, ix)


 @app_query.get("/{index}/documents", response_model=QueryResult)
@@ -115,8 +119,9 @@ def get_documents(
         for field in fields:
             if field in snippets:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Field {field} cannot be in both fields and snippets")
-
-    _check_query_role(indices, user, fields, snippets)
+
+    index_fields = elastic.get_fields(indices)
+    _check_query_role(indices, index_fields, user, fields, snippets)

     args = {}
     sort = sort and [
@@ -274,8 +279,8 @@ def query_documents_post(
             if field in snippets:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Field {field} cannot be in both fields and snippets")

-    #field_meta = elastic.get_fields(index)
-    _check_query_role(indices, user, fields, snippets)
+    index_fields = elastic.get_fields(indices)
+    _check_query_role(indices, index_fields, user, fields, snippets)

     queries = _process_queries(queries)
     filters = dict(_process_filters(filters))
diff --git a/amcat4/elastic.py b/amcat4/elastic.py
index 90fc4ff..1afa827 100644
--- a/amcat4/elastic.py
+++ b/amcat4/elastic.py
@@ -276,7 +276,7 @@ def _get_fields(index: str) -> Iterable[Tuple[str, dict]]:
         t = dict(name=k, type=_get_type_from_property(v))
         if meta := v.get("meta"):
             t["meta"] = meta
-        yield k, t
+        yield k, t


 def get_index_fields(index: str) -> Mapping[str, dict]:
@@ -296,12 +296,44 @@ def get_fields(index: Union[str, Sequence[str]]):
     """
     if isinstance(index, str):
        return get_index_fields(index)
+
+    def get_meta_value(field, meta_key, default):
return field.get("meta", {}).get(meta_key) or default + result = {} for ix in index: for f, ftype in get_index_fields(ix).items(): if f in result: if result[f] != ftype: - result[f] = {"name": f, "type": "keyword", "meta": {"merged": True}} + # for merged fields, use the most restrictive meta settings + metareader_visible_1: bool = get_meta_value(result[f], "metareader_visible", False) + metareader_visible_2: bool = get_meta_value(ftype, "metareader_visible", False) + metareader_visible: bool = metareader_visible_1 and metareader_visible_2 + + metareader_visible_snippet_1: bool = get_meta_value(result[f], "metareader_visible_snippet", False) + metareader_visible_snippet_2: bool = get_meta_value(ftype, "metareader_visible_snippet", False) + metareader_visible_snippet: bool = metareader_visible_snippet_1 and metareader_visible_snippet_2 + + match_snippets_1: int = get_meta_value(result[f], "query_snippets", 0) + match_snippets_2: int = get_meta_value(ftype, "query_snippets", 0) + match_snippets: int = min(match_snippets_1, match_snippets_2) + + match_snippets_size_1: int = get_meta_value(result[f], "query_snippets_size", 0) + match_snippets_size_2: int = get_meta_value(ftype, "query_snippets_size", 0) + match_snippets_size: int = min(match_snippets_size_1, match_snippets_size_2) + + nomatch_snippet_size_1: int = get_meta_value(result[f], "nomatch_snippet_size", 0) + nomatch_snippet_size_2: int = get_meta_value(ftype, "nomatch_snippet_size", 0) + nomatch_snippet_size: int = min(nomatch_snippet_size_1, nomatch_snippet_size_2) + + result[f] = {"name": f, "type": "keyword", "meta": { + "merged": True, + "metareader_visible": metareader_visible, + "metareader_visible_snippet": metareader_visible_snippet, + "query_snippets": match_snippets, + "query_snippets_size": match_snippets_size, + "nomatch_snippet_size": nomatch_snippet_size, + }} else: result[f] = ftype return result @@ -323,6 +355,7 @@ def get_field_values(index: str, field: str, size: int) -> List[str]: r = es().search(index=index, size=0, aggs=aggs) return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]] + def get_field_stats(index: str, field: str) -> List[str]: """ Get field statistics, such as min, max, avg, etc. 
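A minimal sketch of the "most restrictive settings" merge above, assuming two indices define the same field with these made-up meta values (not part of the patch):

    a = {"metareader_visible": True, "query_snippets": 5, "query_snippets_size": 150}
    b = {"metareader_visible": False, "query_snippets": 3, "query_snippets_size": 100}
    merged = {
        "metareader_visible": a["metareader_visible"] and b["metareader_visible"],  # False
        "query_snippets": min(a["query_snippets"], b["query_snippets"]),  # 3
        "query_snippets_size": min(a["query_snippets_size"], b["query_snippets_size"]),  # 100
    }

A merged field is thus only as visible to metareaders as its most restrictive index allows.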
diff --git a/tests/test_api_documents.py b/tests/test_api_documents.py
index e401778..914a8de 100644
--- a/tests/test_api_documents.py
+++ b/tests/test_api_documents.py
@@ -80,10 +80,10 @@ def get_join(x):
         ([index_docs, index], reader, ["text"], None, 200),
         (index, reader, ["title"], ["text"], 200)
     ]:
-        snippets_param = get_join(snippets)
+        snippets_param = ("&snippets=" + get_join(snippets)) if snippets else ""
         check(
             client.get(
-                f"/index/{get_join(ix)}/documents?fields={get_join(fields)}{'&snippets=' + snippets_param if snippets else ''}",
+                f"/index/{get_join(ix)}/documents?fields={get_join(fields)}{snippets_param}",
                 headers=build_headers(u),
             ),
             outcome,

From 4006ab0c6664cadf05973df3dafaea708e7029ce Mon Sep 17 00:00:00 2001
From: Kasper Welbers
Date: Thu, 28 Dec 2023 01:43:28 +0100
Subject: [PATCH 03/80] user management

---
 amcat4/api/users.py | 4 ++--
 amcat4/index.py     | 7 ++++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/amcat4/api/users.py b/amcat4/api/users.py
index a5efd8a..4eda374 100644
--- a/amcat4/api/users.py
+++ b/amcat4/api/users.py
@@ -36,10 +36,10 @@ class ChangeUserForm(BaseModel):
     role: Optional[ROLE] = None


-@app_users.post("/users/", status_code=status.HTTP_201_CREATED)
+@app_users.post("/users", status_code=status.HTTP_201_CREATED)
 def create_user(new_user: UserForm, _=Depends(authenticated_admin)):
     """Create a new user."""
-    if get_global_role(new_user.email) is not None:
+    if get_global_role(new_user.email, only_es=True) is not None:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=f"User {new_user.email} already exists",
diff --git a/amcat4/index.py b/amcat4/index.py
index 7477b9b..440b8fa 100644
--- a/amcat4/index.py
+++ b/amcat4/index.py
@@ -320,15 +320,16 @@ def get_guest_role(index: str) -> Optional[Role]:
     return Role[role]


-def get_global_role(email: str) -> Optional[Role]:
+def get_global_role(email: str, only_es: bool = False) -> Optional[Role]:
     """
     Retrieve the global role of this user

     :returns: a Role object, or None if the user has no role
     """
     # The 'admin' user is given to everyone in the no_auth scenario
-    if email == get_settings().admin_email or email == "admin":
-        return Role.ADMIN
+    if only_es is False:
+        if email == get_settings().admin_email or email == "admin":
+            return Role.ADMIN

     return get_role(index=GLOBAL_ROLES, email=email)

From 3c7bf6b63fd2589e437a0ba4a9526ba201dd7012 Mon Sep 17 00:00:00 2001
From: Kasper Welbers
Date: Thu, 4 Jan 2024 18:58:10 +0100
Subject: [PATCH 04/80] field stuff

---
 amcat4/api/auth.py  | 127 +++++++++++++++++++++++----
 amcat4/api/query.py |  88 ++++++------------
 amcat4/elastic.py   |  92 ++++++++++---------
 amcat4/query.py     | 209 ++++++++++++++++++++++++++++++--------------
 amcat4/util.py      |  26 ++++++
 5 files changed, 359 insertions(+), 183 deletions(-)
 create mode 100644 amcat4/util.py

diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py
index ae87683..f4a3c34 100644
--- a/amcat4/api/auth.py
+++ b/amcat4/api/auth.py
@@ -2,6 +2,7 @@
 import functools
 import logging
 from datetime import datetime
+from typing import Iterable

 import requests
 from authlib.common.errors import AuthlibBaseError
@@ -11,6 +12,8 @@
 from fastapi.security import OAuth2PasswordBearer
 from starlette.status import HTTP_401_UNAUTHORIZED

+from amcat4 import elastic
+from amcat4.util import parse_snippet
 from amcat4.config import get_settings, AuthOptions
 from amcat4.index import Role, get_role, get_global_role

@@ -35,13 +38,15 @@ def verify_token(token: str) -> dict:

     raises an InvalidToken exception if the token could
not be validated """ payload = decode_middlecat_token(token) - if missing := {'email', 'resource', 'exp'} - set(payload.keys()): + if missing := {"email", "resource", "exp"} - set(payload.keys()): raise InvalidToken(f"Invalid token, missing keys {missing}") now = int(datetime.now().timestamp()) - if payload['exp'] < now: + if payload["exp"] < now: raise InvalidToken("Token expired") - if payload['resource'] != get_settings().host: - raise InvalidToken(f"Wrong host! {payload['resource']} != {get_settings().host}") + if payload["resource"] != get_settings().host: + raise InvalidToken( + f"Wrong host! {payload['resource']} != {get_settings().host}" + ) return payload @@ -52,7 +57,7 @@ def decode_middlecat_token(token: str) -> dict: url = get_settings().middlecat_url if not url: raise InvalidToken("No middlecat defined, cannot decrypt middlecat token") - public_key = get_middlecat_config(url)['public_key'] + public_key = get_middlecat_config(url)["public_key"] try: return jwt.decode(token, public_key) except AuthlibBaseError as e: @@ -72,17 +77,24 @@ def check_global_role(user: str, required_role: Role, raise_error=True): try: global_role = get_global_role(user) except Exception as e: - raise HTTPException(status_code=500, detail=f"Error on retrieving user {user}: {e}") + raise HTTPException( + status_code=500, detail=f"Error on retrieving user {user}: {e}" + ) if global_role and global_role >= required_role: return global_role if raise_error: - raise HTTPException(status_code=401, detail=f"User {user} does not have global " - f"{required_role.name.title()} permissions on this instance") + raise HTTPException( + status_code=401, + detail=f"User {user} does not have global " + f"{required_role.name.title()} permissions on this instance", + ) else: return False -def check_role(user: str, required_role: Role, index: str, required_global_role: Role = Role.ADMIN): +def check_role( + user: str, required_role: Role, index: str, required_global_role: Role = Role.ADMIN +): """Check if the given user have at least the given role (in the index, if given), raise Exception otherwise. :param user: The email address of the authenticated user @@ -101,8 +113,87 @@ def check_role(user: str, required_role: Role, index: str, required_global_role: elif actual_role and actual_role >= required_role: return actual_role else: - raise HTTPException(status_code=401, detail=f"User {user} does not have " - f"{required_role.name.title()} permissions on index {index}") + raise HTTPException( + status_code=401, + detail=f"User {user} does not have " + f"{required_role.name.title()} permissions on index {index}", + ) + + +def check_query_allowed( + index: str, user: str, fields: Iterable[str] = None, snippets: Iterable[str] = None +) -> None: + """Check if the given user is allowed to query the given fields and snippets on the given index. 
+ + :param index: The index to check the role on + :param user: The email address of the authenticated user + :param fields: The fields to check + :param snippets: The snippets to check + :return: True if the user is allowed to query the given fields and snippets, False otherwise + """ + role = get_role(index, user) + if role is None: + raise HTTPException( + status_code=401, + detail=f"User {user} does not have a role on index {index}", + ) + if role >= Role.READER: + return True + + # after this, we know the user is a metareader, so we need to check metareader_access + + def check_fields_access(fields, index_fields) -> None: + if fields is None: + return None + + for field in fields: + if field not in index_fields: + continue + field_meta = index_fields[field].get("meta", {}) + metareader_access = field_meta.get("metareader_access", None) + if metareader_access != "read": + raise HTTPException( + status_code=401, + detail=f"METAREADER cannot read {field} on index {index}", + ) + + def check_snippets_access(snippets, index_fields) -> None: + if snippets is None: + return None + + for snippet in snippets: + field, nomatch_chars, max_matches, match_chars = parse_snippet(snippet) + if field not in index_fields: + continue + field_meta = index_fields[field].get("meta", {}) + metareader_access = field_meta.get("metareader_access", None) + if metareader_access is None: + raise HTTPException( + status_code=401, + detail=f"METAREADER cannot read snippet of {field} on index {index}", + ) + if "snippet" in metareader_access: + ( + _, + meta_nomatch_chars, + meta_max_matches, + meta_match_chars, + ) = parse_snippet(metareader_access) + valid_nomatch_chars = nomatch_chars <= meta_nomatch_chars + valid_max_matches = max_matches <= meta_max_matches + valid_match_chars = match_chars <= meta_match_chars + valid = valid_nomatch_chars and valid_max_matches and valid_match_chars + if not valid: + max_params = f"{field}[{meta_nomatch_chars};{meta_max_matches};{meta_match_chars}]" + raise HTTPException( + status_code=401, + detail=f"The requested snippet of {field} on index {index} is too long. 
" + f"max parameters are: {max_params}", + ) + + index_fields = elastic.get_fields(index) + check_fields_access(fields, index_fields) + check_snippets_access(snippets, index_fields) async def authenticated_user(token: str = Depends(oauth2_scheme)) -> str: @@ -114,17 +205,21 @@ async def authenticated_user(token: str = Depends(oauth2_scheme)) -> str: elif auth == AuthOptions.allow_guests: return "guest" else: - raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, - detail="This instance has no guest access, please provide a valid bearer token") + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="This instance has no guest access, please provide a valid bearer token", + ) try: - user = verify_token(token)['email'] + user = verify_token(token)["email"] except Exception: logging.exception("Login failed") raise HTTPException(status_code=401, detail="Invalid token") if auth == AuthOptions.authorized_users_only: if get_global_role(user) is None: - raise HTTPException(status_code=401, - detail=f"The user {user} is not authorized to access this AmCAT instance") + raise HTTPException( + status_code=401, + detail=f"The user {user} is not authorized to access this AmCAT instance", + ) return user diff --git a/amcat4/api/query.py b/amcat4/api/query.py index de8b0c4..eabf8a4 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -8,9 +8,10 @@ from amcat4 import elastic, query, aggregate from amcat4.aggregate import Axis, Aggregation -from amcat4.api.auth import authenticated_user, check_role +from amcat4.api.auth import authenticated_user, check_query_allowed from amcat4.index import Role from amcat4.query import update_tag_query +from amcat4.util import parse_snippet app_query = APIRouter(prefix="/index", tags=["query"]) @@ -32,33 +33,6 @@ class QueryResult(BaseModel): meta: QueryMeta -def _check_query_role( - indices: List[str], index_fields: dict, user: str, fields: List[str], snippets: Optional[List[str]] = None -): - """ - Check whether the user needs to have metareader or reader role. - The index_fields (from elastic.get_fields) contains meta information about - field access in the index. For multiple indices, the most restritive setting is used. - """ - metareader_visible = index_fields.get("meta", {}).get("metareader_visible", []) - metareader_snippet = index_fields.get("meta", {}).get("metareader_snippet", []) - - def visible_to_metareader(fields, metareader_fields): - if (not fields): - return True - return all([x in metareader_fields for x in fields]) - - meta_visible = visible_to_metareader(fields, metareader_visible) - meta_visible_snippet = visible_to_metareader(snippets, metareader_snippet) - if meta_visible and meta_visible_snippet: - required_role = Role.METAREADER - else: - required_role = Role.READER - - for ix in indices: - check_role(user, required_role, ix) - - @app_query.get("/{index}/documents", response_model=QueryResult) def get_documents( index: str, @@ -81,13 +55,17 @@ def get_documents( ), snippets: str = Query( None, - description="Comma separated list of fields to return as snippets", - pattern=r"\w+(,\w+)*", - ), - highlight: bool = Query( - False, - description="If true, highlight fields" + description="Comma separated list of fields to return as snippets. If only field names are given, the default " + "snippet parameters are used. The parameters are 'nomatch_chars' (default: 150), 'max_matches' (default: 3) " + "and 'match_chars' (default: 50). If there is no query, the snippet is the first [nomatch_chars] characters. 
" + "If there is a query, snippets are returned for up to [max_matches] matches, with each match having [match_chars] " + "characters. match snippets are concatenated with ' ... ' and have tags around the matched text. " + "If you want to use custom snippet parameters, you can add a suffix to the field name with the parameters between " + "brackets, in the format: fieldname[nomatch_chars;max_matches;match_chars] (e.g, text[150;3;50]). " + "(always provide all 3 parameters, even if you only want to change one)", + pattern=r"[\w\[;\]]+(,[\w\[;\]]+)*", ), + highlight: bool = Query(False, description="If true, highlight fields"), per_page: int = Query(None, description="Number of results per page"), page: int = Query(None, description="Page to fetch"), scroll: str = Query( @@ -113,16 +91,10 @@ def get_documents( fields = fields and fields.split(",") if not fields: fields = ["date", "title", "url"] - - snippets = snippets and snippets.split(",") - if snippets: - for field in fields: - if field in snippets: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Field {field} cannot be in both fields and snippets") - - index_fields = elastic.get_fields(indices) - _check_query_role(indices, index_fields, user, fields, snippets) - + + for index in indices: + check_query_allowed(indices, user, fields, snippets) + args = {} sort = sort and [ {x.replace(":desc", ""): "desc"} if x.endswith(":desc") else x @@ -215,7 +187,15 @@ def query_documents_post( None, description="List of fields to retrieve for each document" ), snippets: Optional[List[str]] = Body( - None, description="Fields to retrieve as snippets" + None, + description="Fields to retrieve as snippets. If only field names are given, the default " + "snippet parameters are used. The parameters are [nomatch_chars] (default: 200), [max_matches] (default: 3) " + "and [match_chars] (default: 50). If there is no query, the snippet is the first [nomatch_chars] characters. " + "If there is a query, snippets are returned for up to [max_matches] matches, with each match having [match_chars] " + "characters. match snippets also have tags around the matched text. " + "If you want to use custom snippet parameters, you can add a suffix to the field name with the parameters between " + "brackets, in the format: fieldname[nomatch_chars;max_matches;match_chars] (e.g, text[150;3;50]). 
" + "(always provide all 3 parameters, even if you only want to change one)", ), filters: Optional[ Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] @@ -252,10 +232,7 @@ def query_documents_post( scroll_id: Optional[str] = Body( None, description="Scroll id from previous response to continue scrolling" ), - highlight: Optional[bool] = Body( - False, - description="If true, highlight fields" - ), + highlight: Optional[bool] = Body(False, description="If true, highlight fields"), user=Depends(authenticated_user), ): """ @@ -271,16 +248,9 @@ def query_documents_post( fields = [fields] else: fields = ["date", "title", "url"] - - if snippets: - if isinstance(snippets, str): - snippets = [snippets] - for field in fields: - if field in snippets: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Field {field} cannot be in both fields and snippets") - - index_fields = elastic.get_fields(indices) - _check_query_role(indices, index_fields, user, fields, snippets) + + for index in indices: + check_query_allowed(index, user, fields, snippets) queries = _process_queries(queries) filters = dict(_process_filters(filters)) diff --git a/amcat4/elastic.py b/amcat4/elastic.py index 1afa827..f844237 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -19,6 +19,7 @@ from elasticsearch.helpers import bulk from amcat4.config import get_settings +from amcat4.util import parse_snippet SYSTEM_INDEX_VERSION = 1 @@ -43,6 +44,7 @@ "url": ES_MAPPINGS["url"], } + SYSTEM_MAPPING = { "name": {"type": "text"}, "description": {"type": "text"}, @@ -276,7 +278,7 @@ def _get_fields(index: str) -> Iterable[Tuple[str, dict]]: t = dict(name=k, type=_get_type_from_property(v)) if meta := v.get("meta"): t["meta"] = meta - yield k, t + yield k, t def get_index_fields(index: str) -> Mapping[str, dict]: @@ -288,7 +290,7 @@ def get_index_fields(index: str) -> Mapping[str, dict]: return dict(_get_fields(index)) -def get_fields(index: Union[str, Sequence[str]]): +def get_fields(index: Union[str, Sequence[str]]) -> Mapping[str, dict]: """ Get the field types in use in this index or indices :param index: name(s) of index(es) to query @@ -296,44 +298,54 @@ def get_fields(index: Union[str, Sequence[str]]): """ if isinstance(index, str): return get_index_fields(index) - - def get_meta_value(field, meta_key, default): - return field.get("meta", {}).get(meta_key) or default - + + # def get_meta_value(field, meta_key, default): + # return field.get("meta", {}).get(meta_key) or default + + # def get_least_metareader_access(access1, access2): + # if (access1 == None) or (access2 == None): + # return None + + # if "snippet" in access1 and access2 == "read": + # return access1 + + # if "snippet" in access2 and access1 == "read": + # return access2 + + # if "snippet" in access1 and "snippet" in access2: + # _, nomatch_chars1, max_matches1, match_chars1 = parse_snippet(access1) + # _, nomatch_chars2, max_matches2, match_chars2 = parse_snippet(access2) + # nomatch_chars = min(nomatch_chars1, nomatch_chars2) + # max_matches = min(max_matches1, max_matches2) + # match_chars = match_chars1 + match_chars2 + # return f"snippet[{nomatch_chars},{max_matches},{match_chars}]" + + # if access1 == "read" and access2 == "read": + # return "read" + result = {} for ix in index: for f, ftype in get_index_fields(ix).items(): if f in result: if result[f] != ftype: - # for merged fields, use the most restrictive meta settings - metareader_visible_1: bool = get_meta_value(result[f], "metareader_visible", False) - 
metareader_visible_2: bool = get_meta_value(ftype, "metareader_visible", False) - metareader_visible: bool = metareader_visible_1 and metareader_visible_2 - - metareader_visible_snippet_1: bool = get_meta_value(result[f], "metareader_visible_snippet", False) - metareader_visible_snippet_2: bool = get_meta_value(ftype, "metareader_visible_snippet", False) - metareader_visible_snippet: bool = metareader_visible_snippet_1 and metareader_visible_snippet_2 - - match_snippets_1: int = get_meta_value(result[f], "query_snippets", 0) - match_snippets_2: int = get_meta_value(ftype, "query_snippets", 0) - match_snippets: int = min(match_snippets_1, match_snippets_2) - - match_snippets_size_1: int = get_meta_value(result[f], "query_snippets_size", 0) - match_snippets_size_2: int = get_meta_value(ftype, "query_snippets_size", 0) - match_snippets_size: int = min(match_snippets_size_1, match_snippets_size_2) - - nomatch_snippet_size_1: int = get_meta_value(result[f], "nomatch_snippet_size", 0) - nomatch_snippet_size_2: int = get_meta_value(ftype, "nomatch_snippet_size", 0) - nomatch_snippet_size: int = min(nomatch_snippet_size_1, nomatch_snippet_size_2) - - result[f] = {"name": f, "type": "keyword", "meta": { - "merged": True, - "metareader_visible": metareader_visible, - "metareader_visible_snippet": metareader_visible_snippet, - "query_snippets": match_snippets, - "query_snippets_size": match_snippets_size, - "nomatch_snippet_size": nomatch_snippet_size, - }} + # note that for merged fields metareader access is always None + # metareader_access_1: bool = get_meta_value( + # result[f], "metareader_visible", None + # ) + # metareader_access_2: bool = get_meta_value( + # ftype, "metareader_visible", None + # ) + # metareader_access = get_least_metareader_access( + # metareader_access_1, metareader_access_2 + # ) + + result[f] = { + "name": f, + "type": "keyword", + "meta": { + "merged": True, + }, + } else: result[f] = ftype return result @@ -342,23 +354,21 @@ def get_meta_value(field, meta_key, default): def get_field_values(index: str, field: str, size: int) -> List[str]: """ Get the values for a given field (e.g. to populate list of filter values on keyword field) - Results are sorted descending by document frequency + Results are sorted descending by document frequency see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html#search-aggregations-bucket-terms-aggregation-order - + :param index: The index :param field: The field name :return: A list of values """ - aggs = {"unique_values": { - "terms": {"field": field, "size": size} - }} + aggs = {"unique_values": {"terms": {"field": field, "size": size}}} r = es().search(index=index, size=0, aggs=aggs) return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]] def get_field_stats(index: str, field: str) -> List[str]: """ - Get field statistics, such as min, max, avg, etc. + Get field statistics, such as min, max, avg, etc. 
:param index: The index :param field: The field name :return: A list of values diff --git a/amcat4/query.py b/amcat4/query.py index b60ab16..2483b3a 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -2,55 +2,73 @@ All things query """ from math import ceil -import json -from re import finditer -from re import sub -from typing import Mapping, Iterable, Optional, Union, Sequence, Any, Dict, List, Tuple, Literal + +from typing import ( + Mapping, + Iterable, + Optional, + Union, + Sequence, + Any, + Dict, + List, + Tuple, + Literal, +) from .date_mappings import mappings from .elastic import es, update_tag_by_query +from amcat4.util import parse_snippet -def build_body(queries: Iterable[str] = None, filters: Mapping = None, highlight: dict = None, - ids: Iterable[str] = None): +def build_body( + queries: Iterable[str] = None, + filters: Mapping = None, + highlight: dict = None, + ids: Iterable[str] = None, +): def parse_filter(field, filter) -> Tuple[Mapping, Mapping]: filter = filter.copy() extra_runtime_mappings = {} field_filters = [] - for value in filter.pop('values', []): + for value in filter.pop("values", []): field_filters.append({"term": {field: value}}) - if 'value' in filter: - field_filters.append({"term": {field: filter.pop('value')}}) - if 'exists' in filter: - if filter.pop('exists'): + if "value" in filter: + field_filters.append({"term": {field: filter.pop("value")}}) + if "exists" in filter: + if filter.pop("exists"): field_filters.append({"exists": {"field": field}}) else: - field_filters.append({"bool": {"must_not": {"exists": {"field": field}}}}) + field_filters.append( + {"bool": {"must_not": {"exists": {"field": field}}}} + ) for mapping in mappings(): if mapping.interval in filter: value = filter.pop(mapping.interval) extra_runtime_mappings.update(mapping.mapping(field)) field_filters.append({"term": {mapping.fieldname(field): value}}) rangefilter = {} - for rangevar in ['gt', 'gte', 'lt', 'lte']: + for rangevar in ["gt", "gte", "lt", "lte"]: if rangevar in filter: rangefilter[rangevar] = filter.pop(rangevar) if rangefilter: field_filters.append({"range": {field: rangefilter}}) if filter: raise ValueError(f"Unknown filter type(s): {filter}") - return extra_runtime_mappings, {'bool': {'should': field_filters}} + return extra_runtime_mappings, {"bool": {"should": field_filters}} def parse_query(q: str) -> dict: - return {"query_string": {"query": q}} + return {"query_string": {"query": q}} def parse_queries(qs: Sequence[str]) -> dict: if len(qs) == 1: return parse_query(list(qs)[0]) else: return {"bool": {"should": [parse_query(q) for q in qs]}} - if not (queries or filters or ids): - return {'query': {'match_all': {}}} + + if not (queries or filters or ids or highlight): + return {"query": {"match_all": {}}} + fs, runtime_mappings = [], {} if filters: for field, filter in filters.items(): @@ -66,17 +84,24 @@ def parse_queries(qs: Sequence[str]) -> dict: fs.append({"ids": {"values": list(ids)}}) body: Dict[str, Any] = {"query": {"bool": {"filter": fs}}} if runtime_mappings: - body['runtime_mappings'] = runtime_mappings - + body["runtime_mappings"] = runtime_mappings + if highlight is not None: body["highlight"] = highlight - + return body class QueryResult: - def __init__(self, data: List[dict], - n: int = None, per_page: int = None, page: int = None, page_count: int = None, scroll_id: str = None): + def __init__( + self, + data: List[dict], + n: int = None, + per_page: int = None, + page: int = None, + page_count: int = None, + scroll_id: str = None, + ): if n and 
(page_count is None) and (per_page is not None): page_count = ceil(n / per_page) self.data = data @@ -87,18 +112,21 @@ def __init__(self, data: List[dict], self.scroll_id = scroll_id def as_dict(self): - meta = {"total_count": self.total_count, - "per_page": self.per_page, - "page_count": self.page_count, - } + meta = { + "total_count": self.total_count, + "per_page": self.per_page, + "page_count": self.page_count, + } if self.scroll_id: - meta['scroll_id'] = self.scroll_id + meta["scroll_id"] = self.scroll_id else: - meta['page'] = self.page + meta["page"] = self.page return dict(meta=meta, results=self.data) -def _normalize_queries(queries: Optional[Union[Dict[str, str], Iterable[str]]]) -> Mapping[str, str]: +def _normalize_queries( + queries: Optional[Union[Dict[str, str], Iterable[str]]] +) -> Mapping[str, str]: if queries is None: return {} if isinstance(queries, dict): @@ -106,14 +134,21 @@ def _normalize_queries(queries: Optional[Union[Dict[str, str], Iterable[str]]]) return {q: q for q in queries} -def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str, str], Iterable[str]] = None, *, - page: int = 0, per_page: int = 10, - scroll=None, scroll_id: str = None, - fields: Iterable[str] = None, snippets: Iterable[str] = None, - filters: Mapping[str, Mapping] = None, - highlight: Literal["none", "text", "snippets"] = "none", - sort: List[Union[str, Mapping]] = None, - **kwargs) -> Optional[QueryResult]: +def query_documents( + index: Union[str, Sequence[str]], + queries: Union[Mapping[str, str], Iterable[str]] = None, + *, + page: int = 0, + per_page: int = 10, + scroll=None, + scroll_id: str = None, + fields: Iterable[str] = None, + snippets: Iterable[str] = None, + filters: Mapping[str, Mapping] = None, + highlight: Literal["none", "text", "snippets"] = "none", + sort: List[Union[str, Mapping]] = None, + **kwargs, +) -> Optional[QueryResult]: """ Conduct a query_string query, returning the found documents. 
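As an illustration of the new fields/snippets split, a hypothetical call (index name, query, and filter values are made up):

    result = query_documents(
        "my_index",
        queries={"climate": '"climate change"'},
        filters={"date": {"gte": "2023-01-01"}},
        fields=["date", "title"],
        snippets=["text[150;3;50]"],
        per_page=10,
    )

Each hit in result.data then carries the requested fields plus a snippet of the text field.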
@@ -142,42 +177,54 @@ def query_documents(index: Union[str, Sequence[str]], queries: Union[Mapping[str :param kwargs: Additional elements passed to Elasticsearch.search() :return: a QueryResult, or None if there is not scroll result anymore """ + if overlap_fields_snippets(fields, snippets): + raise ValueError("Cannot request both a field AND snippets for this field") + if scroll or scroll_id: # set scroll to default also if scroll_id is given but no scroll time is known - kwargs['scroll'] = '2m' if (not scroll or scroll is True) else scroll + kwargs["scroll"] = "2m" if (not scroll or scroll is True) else scroll queries = _normalize_queries(queries) if sort is not None: kwargs["sort"] = sort if scroll_id: result = es().scroll(scroll_id=scroll_id, **kwargs) - if not result['hits']['hits']: + if not result["hits"]["hits"]: return None else: h = query_highlight(fields, highlight, snippets) body = build_body(queries.values(), filters, h) - + if fields: fields = fields if isinstance(fields, list) else list(fields) - kwargs['_source'] = fields + kwargs["_source"] = fields if not scroll: - kwargs['from_'] = page * per_page + kwargs["from_"] = page * per_page result = es().search(index=index, size=per_page, **body, **kwargs) data = [] - for hit in result['hits']['hits']: - hitdict = dict(_id=hit['_id'], **hit['_source']) + for hit in result["hits"]["hits"]: + hitdict = dict(_id=hit["_id"], **hit["_source"]) hitdict = overwrite_highlight_results(hit, hitdict) - if 'highlight' in hit: - for key in hit['highlight'].keys(): - if hit['highlight'][key]: - hitdict[key] = " ... ".join(hit['highlight'][key]) + if "highlight" in hit: + for key in hit["highlight"].keys(): + if hit["highlight"][key]: + hitdict[key] = " ... ".join(hit["highlight"][key]) data.append(hitdict) if scroll_id: - return QueryResult(data, n=result['hits']['total']['value'], scroll_id=result['_scroll_id']) + return QueryResult( + data, n=result["hits"]["total"]["value"], scroll_id=result["_scroll_id"] + ) elif scroll: - return QueryResult(data, n=result['hits']['total']['value'], per_page=per_page, scroll_id=result['_scroll_id']) + return QueryResult( + data, + n=result["hits"]["total"]["value"], + per_page=per_page, + scroll_id=result["_scroll_id"], + ) else: - return QueryResult(data, n=result['hits']['total']['value'], per_page=per_page, page=page) + return QueryResult( + data, n=result["hits"]["total"]["value"], per_page=per_page, page=page + ) def query_highlight(fields: Iterable[str], highlight: bool, snippets: Iterable[str]): @@ -187,38 +234,66 @@ def query_highlight(fields: Iterable[str], highlight: bool, snippets: Iterable[s """ if (fields is None or highlight is False) and (snippets is None): return None - + highlight = {"require_field_match": False, "fields": {}} - + if fields is not None: for field in fields: highlight["fields"][field] = {"number_of_fragments": 0} - + if snippets is not None: # TODO: get index meta data to see which snippets are allowed and what # the nr and size should be - for field in snippets: - highlight["fields"][field] = {"no_match_size": 200, "number_of_fragments": 3, "fragment_size": 40} - + for snippet in snippets: + field, nomatch_chars, max_matches, match_chars = parse_snippet(snippet) + highlight["fields"][field] = { + "no_match_size": nomatch_chars, + "number_of_fragments": max_matches, + "fragment_size": match_chars, + } + return highlight + def overwrite_highlight_results(hit: dict, hitdict: dict): """ highlights are a separate field in the hits. 
If highlight is True, we want to overwrite the original field with the highlighted version. If there are snippets, we want to add them """ - if not hit.get('highlight'): + if not hit.get("highlight"): return hitdict - for key in hit['highlight'].keys(): - if hit['highlight'][key]: - hitdict[key] = " ... ".join(hit['highlight'][key]) + for key in hit["highlight"].keys(): + if hit["highlight"][key]: + hitdict[key] = " ... ".join(hit["highlight"][key]) return hitdict - -def update_tag_query(index: Union[str, Sequence[str]], action: Literal["add", "remove"], - field: str, tag: str, - queries: Union[Mapping[str, str], Iterable[str]] = None, - filters: Mapping[str, Mapping] = None, - ids: Sequence[str] = None): + + +def update_tag_query( + index: Union[str, Sequence[str]], + action: Literal["add", "remove"], + field: str, + tag: str, + queries: Union[Mapping[str, str], Iterable[str]] = None, + filters: Mapping[str, Mapping] = None, + ids: Sequence[str] = None, +): """Add or remove tags using a query""" body = build_body(queries and queries.values(), filters, ids=ids) update_tag_by_query(index, action, body, field, tag) + + +def overlap_fields_snippets( + fields: Iterable[str] = None, snippets: Iterable[str] = None +) -> bool: + """ + If both fields and snippets are requested as output, check if there are any overlaps + """ + if fields is None or snippets is None: + return False + + for snippet in snippets: + field, _, _, _ = parse_snippet(snippet) + if field in fields: + return True + + return False diff --git a/amcat4/util.py b/amcat4/util.py new file mode 100644 index 0000000..ee1af8a --- /dev/null +++ b/amcat4/util.py @@ -0,0 +1,26 @@ +import re +from typing import Tuple + + +def parse_snippet(snippet: str) -> Tuple[str, int, int, int]: + """ + Parse a snippet string into a field and the snippet parameters. + The format is fieldname[nomatch_chars;max_matches;match_chars]. + If the snippet does not contain parameters (or the specification is wrong), + we assume the snippet is just the field name and use default values. 
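+
+    For example (intended behavior):
+        parse_snippet("text[150;3;50]") -> ("text", 150, 3, 50)
+        parse_snippet("text") -> ("text", 200, 3, 50)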
+ """ + pattern = r"\[([0-9]+);([0-9]+);([0-9]+)]$" + match = re.match(pattern, snippet) + + if match: + field = snippet[: match.start()] + nomatch_chars = int(match.group(1)) + max_matches = int(match.group(2)) + match_chars = int(match.group(3)) + else: + field = snippet + nomatch_chars = 200 + max_matches = 3 + match_chars = 50 + + return field, nomatch_chars, max_matches, match_chars From 14790d4467bc7ad0f9cd56aae93e9d89d7843565 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 5 Jan 2024 09:16:40 +0100 Subject: [PATCH 05/80] metareader_access seems to work --- amcat4/elastic.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/amcat4/elastic.py b/amcat4/elastic.py index f844237..98d0972 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -24,17 +24,30 @@ SYSTEM_INDEX_VERSION = 1 ES_MAPPINGS = { - "long": {"type": "long"}, - "date": {"type": "date", "format": "strict_date_optional_time"}, - "double": {"type": "double"}, - "keyword": {"type": "keyword"}, - "url": {"type": "keyword", "meta": {"amcat4_type": "url"}}, - "tag": {"type": "keyword", "meta": {"amcat4_type": "tag"}}, - "id": {"type": "keyword", "meta": {"amcat4_type": "id"}}, - "text": {"type": "text"}, - "object": {"type": "object"}, - "geo_point": {"type": "geo_point"}, - "dense_vector": {"type": "dense_vector"}, + "long": {"type": "long", "meta": {"metareader_access": "read"}}, + "date": { + "type": "date", + "format": "strict_date_optional_time", + "meta": {"metareader_access": "read"}, + }, + "double": {"type": "double", "meta": {"metareader_access": "read"}}, + "keyword": {"type": "keyword", "meta": {"metareader_access": "read"}}, + "url": { + "type": "keyword", + "meta": {"amcat4_type": "url", "metareader_access": "read"}, + }, + "tag": { + "type": "keyword", + "meta": {"amcat4_type": "tag", "metareader_access": "read"}, + }, + "id": { + "type": "keyword", + "meta": {"amcat4_type": "id", "metareader_access": "read"}, + }, + "text": {"type": "text", "meta": {"metareader_access": "none"}}, + "object": {"type": "object", "meta": {"metareader_access": "none"}}, + "geo_point": {"type": "geo_point", "metareader_access": "read"}, + "dense_vector": {"type": "dense_vector", "metareader_access": "none"}, } DEFAULT_MAPPING = { From 3310f45307dc7b7e2fadf21c1c5222423c17119c Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Sat, 6 Jan 2024 12:52:25 +0100 Subject: [PATCH 06/80] added tests --- amcat4/api/auth.py | 24 +++++---- amcat4/api/query.py | 10 ++-- amcat4/elastic.py | 37 ++++++++++++- amcat4/query.py | 15 ++++-- amcat4/util.py | 2 +- tests/test_api_documents.py | 54 ------------------- tests/test_api_metareader.py | 102 +++++++++++++++++++++++++++++++++++ tests/test_api_pagination.py | 4 +- tests/test_api_query.py | 3 +- tests/test_query.py | 31 ++++++----- 10 files changed, 194 insertions(+), 88 deletions(-) create mode 100644 tests/test_api_metareader.py diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py index f4a3c34..f4ecb5c 100644 --- a/amcat4/api/auth.py +++ b/amcat4/api/auth.py @@ -129,7 +129,7 @@ def check_query_allowed( :param user: The email address of the authenticated user :param fields: The fields to check :param snippets: The snippets to check - :return: True if the user is allowed to query the given fields and snippets, False otherwise + :return: Nothing. Throws HTTPException if the user is not allowed to query the given fields and snippets. 
""" role = get_role(index, user) if role is None: @@ -147,7 +147,7 @@ def check_fields_access(fields, index_fields) -> None: return None for field in fields: - if field not in index_fields: + if field not in index_fields.keys(): continue field_meta = index_fields[field].get("meta", {}) metareader_access = field_meta.get("metareader_access", None) @@ -163,16 +163,15 @@ def check_snippets_access(snippets, index_fields) -> None: for snippet in snippets: field, nomatch_chars, max_matches, match_chars = parse_snippet(snippet) - if field not in index_fields: + if field not in index_fields.keys(): continue + field_meta = index_fields[field].get("meta", {}) - metareader_access = field_meta.get("metareader_access", None) - if metareader_access is None: - raise HTTPException( - status_code=401, - detail=f"METAREADER cannot read snippet of {field} on index {index}", - ) - if "snippet" in metareader_access: + metareader_access = field_meta.get("metareader_access", "none") + + if metareader_access == "read": + continue + elif "snippet" in metareader_access: ( _, meta_nomatch_chars, @@ -190,6 +189,11 @@ def check_snippets_access(snippets, index_fields) -> None: detail=f"The requested snippet of {field} on index {index} is too long. " f"max parameters are: {max_params}", ) + else: + raise HTTPException( + status_code=401, + detail=f"METAREADER cannot read snippet of {field} on index {index}", + ) index_fields = elastic.get_fields(index) check_fields_access(fields, index_fields) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index eabf8a4..0c7e4d4 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -89,11 +89,12 @@ def get_documents( """ indices = index.split(",") fields = fields and fields.split(",") - if not fields: + snippets = snippets and snippets.split(",") + if not fields and not snippets: fields = ["date", "title", "url"] for index in indices: - check_query_allowed(indices, user, fields, snippets) + check_query_allowed(index, user, fields, snippets) args = {} sort = sort and [ @@ -246,7 +247,10 @@ def query_documents_post( if fields: if isinstance(fields, str): fields = [fields] - else: + if snippets: + if isinstance(snippets, str): + snippets = [snippets] + if not fields and not snippets: fields = ["date", "title", "url"] for index in indices: diff --git a/amcat4/elastic.py b/amcat4/elastic.py index 98d0972..e81668d 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -13,6 +13,7 @@ import hashlib import json import logging +import re from typing import Mapping, List, Iterable, Optional, Tuple, Union, Sequence, Literal from elasticsearch import Elasticsearch, NotFoundError @@ -226,10 +227,44 @@ def get_field_mapping(type_: Union[str, dict]): meta = mapping.get("meta", {}) if m := type_.get("meta"): meta.update(m) - mapping["meta"] = meta + mapping["meta"] = validate_field_meta(meta) return mapping +def validate_field_meta(meta: dict): + """ + Elastic has limited available field meta. 
Here we validate the allowed keys (and values) + """ + valid_fields = ["amcat4_type", "metareader_access", "client_display"] + + for meta_field in meta.keys(): + # Validate keys + if meta_field not in valid_fields: + raise ValueError(f"Invalid meta field: {meta_field}") + + # Validate values + if not isinstance(meta[meta_field], str): + raise ValueError("Meta field value has to be a string") + + if meta_field == "amcat4_type": + if meta[meta_field] not in ES_MAPPINGS.keys(): + raise ValueError(f"Invalid amcat4_type value: {meta[meta_field]}") + + if meta_field == "client_display": + # client_display only concerns the client + continue + + if meta_field == "metareader_access": + # metareader_access can be "none", "read", or "snippet" + # if snippet, can also include the maximum snippet parameters (nomatch_chars, max_matches, match_chars) + # in the format: snippet[nomatch_chars;max_matches;match_chars] + reg = r"^(read|none|snippet(\[\d+;\d+;\d+\])?)$" + if not re.match(reg, meta[meta_field]): + raise ValueError(f"Invalid metareader_access value: {meta[meta_field]}") + + return meta + + def set_fields(index: str, fields: Mapping[str, str]): """ Update the column types for this index diff --git a/amcat4/query.py b/amcat4/query.py index 2483b3a..2de0ca1 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -177,8 +177,12 @@ def query_documents( :param kwargs: Additional elements passed to Elasticsearch.search() :return: a QueryResult, or None if there is not scroll result anymore """ + if fields is not None and not isinstance(fields, list): + raise ValueError("fields should be a list") + if snippets is not None and not isinstance(snippets, list): + raise ValueError("snippets should be a list") if overlap_fields_snippets(fields, snippets): - raise ValueError("Cannot request both a field AND snippets for this field") + raise ValueError("Cannot request a field AND it's snippet at the same time") if scroll or scroll_id: # set scroll to default also if scroll_id is given but no scroll time is known @@ -232,10 +236,15 @@ def query_highlight(fields: Iterable[str], highlight: bool, snippets: Iterable[s The elastic "highlight" parameters works for both highlighting text fields and adding snippets. This function will return the highlight parameter to be added to the query body. """ - if (fields is None or highlight is False) and (snippets is None): + if highlight is False and snippets is None: return None - highlight = {"require_field_match": False, "fields": {}} + highlight = { + "pre_tags": [""] if highlight is True else [""], + "post_tags": [""] if highlight is True else [""], + "require_field_match": True, + "fields": {}, + } if fields is not None: for field in fields: diff --git a/amcat4/util.py b/amcat4/util.py index ee1af8a..39ce16d 100644 --- a/amcat4/util.py +++ b/amcat4/util.py @@ -10,7 +10,7 @@ def parse_snippet(snippet: str) -> Tuple[str, int, int, int]: we assume the snippet is just the field name and use default values. 
""" pattern = r"\[([0-9]+);([0-9]+);([0-9]+)]$" - match = re.match(pattern, snippet) + match = re.search(pattern, snippet) if match: field = snippet[: match.start()] diff --git a/tests/test_api_documents.py b/tests/test_api_documents.py index 914a8de..ea21d5c 100644 --- a/tests/test_api_documents.py +++ b/tests/test_api_documents.py @@ -1,5 +1,4 @@ from amcat4.index import set_role, Role -from tests.conftest import populate_index from tests.tools import post_json, build_headers, get_json, check @@ -48,56 +47,3 @@ def test_documents(client, index, user): assert get_json(client, url, user=user)["title"] == "the headline" check(client.delete(url, headers=build_headers(user)), 204) check(client.get(url, headers=build_headers(user)), 404) - - -def test_metareader(client, index, index_docs, user, reader): - set_role(index, user, Role.METAREADER) - set_role(index, reader, Role.READER) - populate_index(index) - - r = get_json( - client, - f"/index/{index}/documents?fields=title", - headers=build_headers(user), - ) - _id = r["results"][0]["_id"] - url = f"index/{index}/documents/{_id}" - # Metareader should not be able to retrieve document source - check(client.get(url, headers=build_headers(user)), 401) - check(client.get(url, headers=build_headers(reader)), 200) - - def get_join(x): - return ",".join(x) if isinstance(x, list) else x - - # Metareader should not be able to query text as a field. Only as a snippet - for ix, u, fields, snippets, outcome in [ - (index, user, ["text"], None, 401), - (index_docs, user, ["text"], None, 200), - ([index_docs, index], user, ["text"], None, 401), - (index, user, ["text", "title"], None, 401), - (index, user, ["title"], None, 200), - (index, reader, ["text"], None, 200), - ([index_docs, index], reader, ["text"], None, 200), - (index, reader, ["title"], ["text"], 200) - ]: - snippets_param = ("&snippets" + get_join(snippets)) if snippets else "" - check( - client.get( - f"/index/{get_join(ix)}/documents?fields={get_join(fields)}{snippets_param}", - headers=build_headers(u), - ), - outcome, - msg=f"Index: {ix}, user: {u}, fields: {fields}", - ) - body = {"fields": fields} - if snippets: - body["snippets"] = snippets - check( - client.post( - f"/index/{get_join(ix)}/query", - headers=build_headers(u), - json=body, - ), - outcome, - msg=f"Index: {ix}, user: {u}, fields: {fields}", - ) diff --git a/tests/test_api_metareader.py b/tests/test_api_metareader.py new file mode 100644 index 0000000..47927f6 --- /dev/null +++ b/tests/test_api_metareader.py @@ -0,0 +1,102 @@ +from fastapi.testclient import TestClient + +from tests.tools import get_json, build_headers, post_json + + +def create_index_metareader(client, index, admin): + # Create new user and set index role to metareader + client.post( + f"/users", + headers=build_headers(admin), + json={"email": "meta@reader.com", "role": "METAREADER"}, + ), + client.put( + f"/index/{index}/users/meta@reader.com", + headers=build_headers(admin), + json={"role": "METAREADER"}, + ), + + +def set_metareader_access(client, index, admin, access): + client.post( + f"/index/{index}/fields", + headers=build_headers(admin), + json={"text": {"type": "text", "meta": {"metareader_access": access}}}, + ) + + +def check_allowed(client, index, field=None, snippet=None, allowed=True): + params = {} + body = {} + + if field: + params["fields"] = field + body["fields"] = [field] + if snippet: + params["snippets"] = snippet + body["snippets"] = [snippet] + + get_json( + client, + f"/index/{index}/documents", + user="meta@reader.com", + 
+        expected=200 if allowed else 401,
+        params=params,
+    )
+    post_json(
+        client,
+        f"/index/{index}/query",
+        user="meta@reader.com",
+        expected=200 if allowed else 401,
+        json=body,
+    )
+
+
+def test_metareader_none(client: TestClient, admin, index_docs):
+    """
+    Set text field to metareader_access=none
+    Metareader should not be able to get field both full and as snippet
+    """
+    create_index_metareader(client, index_docs, admin)
+    set_metareader_access(client, index_docs, admin, "none")
+    check_allowed(client, index_docs, field="text", allowed=False)
+    check_allowed(client, index_docs, snippet="text", allowed=False)
+
+
+def test_metareader_read(client: TestClient, admin, index_docs):
+    """
+    Set text field to metareader_access=read
+    Metareader should be able to get field both full and as snippet
+    """
+    create_index_metareader(client, index_docs, admin)
+    set_metareader_access(client, index_docs, admin, "read")
+    check_allowed(client, index_docs, field="text", allowed=True)
+    check_allowed(client, index_docs, snippet="text", allowed=True)
+
+
+def test_metareader_snippet(client: TestClient, admin, index_docs):
+    """
+    Set text field to metareader_access=snippet
+    Metareader should be able to get field as snippet, but not full
+    """
+    create_index_metareader(client, index_docs, admin)
+
+    set_metareader_access(client, index_docs, admin, "snippet")
+    check_allowed(client, index_docs, field="text", allowed=False)
+    check_allowed(client, index_docs, snippet="text", allowed=True)
+
+
+def test_metareader_snippet_params(client: TestClient, admin, index_docs):
+    """
+    Set text field to metareader_access=snippet[50;1;20]
+    Metareader should only be able to get field as snippet
+    with maximum parameters of nomatch_chars=50, max_matches=1, match_chars=20
+    """
+    create_index_metareader(client, index_docs, admin)
+
+    set_metareader_access(client, index_docs, admin, "snippet[50;1;20]")
+    check_allowed(client, index_docs, field="text", allowed=False)
+    check_allowed(client, index_docs, snippet="text", allowed=False)
+    check_allowed(client, index_docs, snippet="text[51;1;20]", allowed=False)
+    check_allowed(client, index_docs, snippet="text[50;1;20]", allowed=True)
+    check_allowed(client, index_docs, snippet="text[49;1;20]", allowed=True)
diff --git a/tests/test_api_pagination.py b/tests/test_api_pagination.py
index 62c524d..8139013 100644
--- a/tests/test_api_pagination.py
+++ b/tests/test_api_pagination.py
@@ -5,7 +5,7 @@
 
 def test_pagination(client, index, user):
     """Does basic pagination work?"""
-    set_role(index, user, Role.METAREADER)
+    set_role(index, user, Role.READER)
     upload(index, docs=[{"i": i} for i in range(66)])
     url = f"/index/{index}/documents"
@@ -42,7 +42,7 @@ def test_pagination(client, index, user):
 
 
 def test_scroll(client, index, user):
-    set_role(index, user, Role.METAREADER)
+    set_role(index, user, Role.READER)
     upload(index, docs=[{"i": i} for i in range(66)])
     url = f"/index/{index}/documents"
     r = get_json(
diff --git a/tests/test_api_query.py b/tests/test_api_query.py
index 46d25f0..eb51992 100644
--- a/tests/test_api_query.py
+++ b/tests/test_api_query.py
@@ -48,7 +48,8 @@ def q(**query_string):
     def qi(**query_string):
         return {int(doc["_id"]) for doc in q(**query_string)}
 
-    # TODO: check auth
+    # TODO make sure all auth is checked in test_api_query_auth
+
     # Query strings
     assert qi(q="text") == {0, 1}
     assert qi(q="test*") == {1, 2, 3}
diff --git a/tests/test_query.py b/tests/test_query.py
index 3491c49..3669613 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -8,9 +8,9
@@ def query_ids(index: str, q: Optional[str] = None, **kwargs) -> Set[int]: if q is not None: - kwargs['queries'] = [q] + kwargs["queries"] = [q] res = query.query_documents(index, **kwargs) - return {int(h['_id']) for h in res.data} + return {int(h["_id"]) for h in res.data} def test_query(index_docs): @@ -43,18 +43,23 @@ def test_highlight(index): words = "The error of regarding functional notions is not quite equivalent to" text = f"{words} a test document. {words} other text documents. {words} you!" upload(index, [dict(title="Een test titel", text=text)]) - res = query.query_documents(index, queries=["te*"], highlight=True) + res = query.query_documents( + index, fields=["title", "text"], queries=["te*"], highlight=True + ) doc = res.data[0] - assert doc['title'] == "Een test titel" - assert doc['text'] == f"{words} a test document. {words} other text documents. {words} you!" - - doc = query.query_documents(index, queries=["te*"], highlight={"number_of_fragments": 1}).data[0] - assert doc['title'] == "Een test titel" - assert " a test" in doc['text'] - assert "..." not in doc['text'] - - doc = query.query_documents(index, queries=["te*"], highlight={"number_of_fragments": 2}).data[0] - assert re.search(r" a test[^<]*...[^<]*other text documents", doc['text']) + assert doc["title"] == "Een test titel" + assert ( + doc["text"] + == f"{words} a test document. {words} other text documents. {words} you!" + ) + + # snippets can also have highlights + doc = query.query_documents( + index, queries=["te*"], fields=["title"], snippets=["text"], highlight=True + ).data[0] + assert doc["title"] == "Een test titel" + assert " a test" in doc["text"] + assert " ... " in doc["text"] def test_query_multiple_index(index_docs, index): From f2bc7a3fc5616c274e1ba07e90b9981f89070318 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 10 Jan 2024 09:34:30 +0100 Subject: [PATCH 07/80] sstuff and vscode settings --- .dmypy.json | 1 + .gitignore | 5 +- .vscode/extensions.json | 7 +++ .vscode/settings.json | 9 +++ amcat4/api/auth.py | 95 +++++++++++++---------------- amcat4/api/index.py | 16 +++-- amcat4/api/query.py | 113 ++++++++++++++++++++++------------- amcat4/elastic.py | 102 +++++++++++-------------------- amcat4/index.py | 83 ++++++++++++++++++++----- amcat4/models.py | 33 ++++++++++ amcat4/query.py | 75 ++++++++++------------- amcat4/util.py | 21 ++++--- tests/test_api_metareader.py | 28 ++------- tests/test_api_query.py | 69 ++++++++------------- 14 files changed, 351 insertions(+), 306 deletions(-) create mode 100644 .dmypy.json create mode 100644 .vscode/extensions.json create mode 100644 .vscode/settings.json create mode 100644 amcat4/models.py diff --git a/.dmypy.json b/.dmypy.json new file mode 100644 index 0000000..f9027c8 --- /dev/null +++ b/.dmypy.json @@ -0,0 +1 @@ +{"pid": 208431, "connection_name": "/tmp/tmpccrhv4cx/dmypy.sock"} diff --git a/.gitignore b/.gitignore index 82bb2ac..0c9bf0f 100644 --- a/.gitignore +++ b/.gitignore @@ -50,7 +50,10 @@ nosetests.xml .idea # vscode meuk -.vscode +#.vscode +.vscode/* +!.vscode/settings.json +!.vscode/extensions.json # static files navigator/media/static diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..21309ef --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,7 @@ +{ + "recommendations": [ + "matangover.mypy", + "ms-python.python", + "ms-python.black-formatter" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..862f43b --- /dev/null +++ 
b/.vscode/settings.json @@ -0,0 +1,9 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true + }, + "black-formatter.args": ["--line-length", "127"], + "mypy.enabled": true, + "mypy.runUsingActiveInterpreter": true +} diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py index f4ecb5c..f2e542a 100644 --- a/amcat4/api/auth.py +++ b/amcat4/api/auth.py @@ -13,7 +13,7 @@ from starlette.status import HTTP_401_UNAUTHORIZED from amcat4 import elastic -from amcat4.util import parse_snippet +from amcat4.util import parse_field from amcat4.config import get_settings, AuthOptions from amcat4.index import Role, get_role, get_global_role @@ -120,9 +120,7 @@ def check_role( ) -def check_query_allowed( - index: str, user: str, fields: Iterable[str] = None, snippets: Iterable[str] = None -) -> None: +def check_fields_access(index: str, user: str, fields: Iterable[str]) -> None: """Check if the given user is allowed to query the given fields and snippets on the given index. :param index: The index to check the role on @@ -131,6 +129,7 @@ def check_query_allowed( :param snippets: The snippets to check :return: Nothing. Throws HTTPException if the user is not allowed to query the given fields and snippets. """ + role = get_role(index, user) if role is None: raise HTTPException( @@ -138,67 +137,53 @@ def check_query_allowed( detail=f"User {user} does not have a role on index {index}", ) if role >= Role.READER: - return True + return None + if fields is None: + return None # after this, we know the user is a metareader, so we need to check metareader_access - - def check_fields_access(fields, index_fields) -> None: - if fields is None: - return None - - for field in fields: - if field not in index_fields.keys(): - continue - field_meta = index_fields[field].get("meta", {}) - metareader_access = field_meta.get("metareader_access", None) - if metareader_access != "read": + index_fields = elastic.get_fields(index) + for field in fields: + fieldname, nomatch_chars, max_matches, match_chars = parse_field(field) + if fieldname not in index_fields.keys(): + continue + meta = index_fields[fieldname].get("meta", {}) + metareader_access = meta.get("metareader_access", None) + if not metareader_access or metareader_access == "none": + raise HTTPException( + status_code=401, + detail=f"METAREADER cannot read {field} on index {index}", + ) + if metareader_access == "read": + continue + if metareader_access.startswith("snippet"): + ( + _, + allowed_nomatch_chars, + allowed_max_matches, + allowed_match_chars, + ) = parse_field(metareader_access) + max_params = f"{fieldname}[{allowed_nomatch_chars};{allowed_max_matches};{allowed_match_chars}]" + + if nomatch_chars is None: raise HTTPException( status_code=401, - detail=f"METAREADER cannot read {field} on index {index}", + detail=f"METAREADER cannot read {field} on index {index}. 
" + f"Can only read snippets with parameters: {max_params}", ) - def check_snippets_access(snippets, index_fields) -> None: - if snippets is None: - return None - - for snippet in snippets: - field, nomatch_chars, max_matches, match_chars = parse_snippet(snippet) - if field not in index_fields.keys(): - continue - - field_meta = index_fields[field].get("meta", {}) - metareader_access = field_meta.get("metareader_access", "none") - - if metareader_access == "read": - continue - elif "snippet" in metareader_access: - ( - _, - meta_nomatch_chars, - meta_max_matches, - meta_match_chars, - ) = parse_snippet(metareader_access) - valid_nomatch_chars = nomatch_chars <= meta_nomatch_chars - valid_max_matches = max_matches <= meta_max_matches - valid_match_chars = match_chars <= meta_match_chars - valid = valid_nomatch_chars and valid_max_matches and valid_match_chars - if not valid: - max_params = f"{field}[{meta_nomatch_chars};{meta_max_matches};{meta_match_chars}]" - raise HTTPException( - status_code=401, - detail=f"The requested snippet of {field} on index {index} is too long. " - f"max parameters are: {max_params}", - ) - else: + valid_nomatch_chars = nomatch_chars <= allowed_nomatch_chars + valid_max_matches = max_matches <= allowed_max_matches + valid_match_chars = match_chars <= allowed_match_chars + + valid = valid_nomatch_chars and valid_max_matches and valid_match_chars + if not valid: raise HTTPException( status_code=401, - detail=f"METAREADER cannot read snippet of {field} on index {index}", + detail=f"The requested snippet of {fieldname} on index {index} is too long. " + f"max parameters are: {max_params}", ) - index_fields = elastic.get_fields(index) - check_fields_access(fields, index_fields) - check_snippets_access(snippets, index_fields) - async def authenticated_user(token: str = Depends(oauth2_scheme)) -> str: """Dependency to verify and return a user based on a token.""" diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 3fb5584..b1c22fc 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -272,8 +272,12 @@ def get_fields(ix: str, user=Depends(authenticated_user)): Returns a json array of {name, type} objects """ check_role(user, Role.METAREADER, ix) - indices = ix.split(",") - return elastic.get_fields(indices) + if "," in ix: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"/[index]/fields does not support multiple indices", + ) + return elastic.get_fields(ix) @app_index.post("/{ix}/fields") @@ -291,11 +295,11 @@ def set_fields( @app_index.get("/{ix}/fields/{field}/values") -def get_field_values(ix: str, field: str, user: str =Depends(authenticated_user)): +def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user)): """ Get unique values for a specific field. Should mainly/only be used for tag fields. Main purpose is to provide a list of values for a dropdown menu. - + TODO: at the moment 'only' returns top 2000 values. Currently throws an error if there are more than 2000 unique values. We can increase this limit, but there should be a limit. Querying could be an option, but not sure if that is @@ -310,12 +314,14 @@ def get_field_values(ix: str, field: str, user: str =Depends(authenticated_user) ) return values + @app_index.get("/{ix}/fields/{field}/stats") -def get_field_stats(ix: str, field: str, user: str =Depends(authenticated_user)): +def get_field_stats(ix: str, field: str, user: str = Depends(authenticated_user)): """Get statistics for a specific value. 
Only works for numeric (incl date) fields.""" check_role(user, Role.READER, ix) return elastic.get_field_stats(ix, field) + @app_index.get("/{ix}/users") def list_index_users(ix: str, user: str = Depends(authenticated_user)): """ diff --git a/amcat4/api/query.py b/amcat4/api/query.py index 0c7e4d4..fba7a9d 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -8,10 +8,10 @@ from amcat4 import elastic, query, aggregate from amcat4.aggregate import Axis, Aggregation -from amcat4.api.auth import authenticated_user, check_query_allowed -from amcat4.index import Role +from amcat4.api.auth import authenticated_user, check_fields_access +from amcat4.index import Role, get_role from amcat4.query import update_tag_query -from amcat4.util import parse_snippet +from amcat4.util import parse_field app_query = APIRouter(prefix="/index", tags=["query"]) @@ -33,6 +33,59 @@ class QueryResult(BaseModel): meta: QueryMeta +def get_or_validate_allowed_fields( + user: str, indices: Iterable[str], fields: Iterable[str] = None +): + """ + For any endpoint that returns field values, make sure the user only gets fields that + they are allowed to see. If fields is None, return all allowed fields. If fields is not None, + check whether the user can access the fields (If not, raise an error). + """ + + if not isinstance(user, str): + raise ValueError("User should be a string") + if not isinstance(indices, list): + raise ValueError("Indices should be a list") + if fields is not None and not isinstance(fields, list): + raise ValueError("Fields should be a list or None") + + if fields is None: + if len(indices) > 1: + # this restrictions is needed, because otherwise we need to return all allowed fields taking + # into account the user's role for each index, and take the lowest possible access. + # this is error prone and complex, so best to just disallow it. Also, requesting all fields + # for multiple indices is probably not something we should support anyway + raise ValueError("Fields should be specified if multiple indices are given") + index_fields = elastic.get_fields(indices[0]) + role = get_role(indices[0], user) + all_fields = [] + for field in index_fields.keys(): + if role >= Role.READER: + all_fields.append(field) + elif role == Role.METAREADER: + field_meta: dict = index_fields[field].get("meta", {}) + metareader_access = field_meta.get("metareader_access", None) + if metareader_access == "read": + all_fields.append(field) + elif "snippet" in metareader_access: + _, nomatch_chars, max_matches, match_chars = parse_field( + metareader_access + ) + all_fields.append( + f"{field}[{nomatch_chars};{max_matches};{match_chars}]" + ) + else: + raise HTTPException( + status_code=401, + detail=f"User {user} does not have a role on index {indices[0]}", + ) + + else: + for index in indices: + check_fields_access(index, user, fields) + return fields + + @app_query.get("/{index}/documents", response_model=QueryResult) def get_documents( index: str, @@ -50,20 +103,12 @@ def get_documents( ), fields: str = Query( None, - description="Comma separated list of fields to return", - pattern=r"\w+(,\w+)*", - ), - snippets: str = Query( - None, - description="Comma separated list of fields to return as snippets. If only field names are given, the default " - "snippet parameters are used. The parameters are 'nomatch_chars' (default: 150), 'max_matches' (default: 3) " - "and 'match_chars' (default: 50). If there is no query, the snippet is the first [nomatch_chars] characters. 
" + description="Comma separated list of fields to return. " + "You can also request a snippet of a field by appending the suffix [nomatch_chars;max_matches;match_chars]. " + "'matches' here refers to words from text queries. If there is no query, the snippet is the first [nomatch_chars] characters. " "If there is a query, snippets are returned for up to [max_matches] matches, with each match having [match_chars] " - "characters. match snippets are concatenated with ' ... ' and have tags around the matched text. " - "If you want to use custom snippet parameters, you can add a suffix to the field name with the parameters between " - "brackets, in the format: fieldname[nomatch_chars;max_matches;match_chars] (e.g, text[150;3;50]). " - "(always provide all 3 parameters, even if you only want to change one)", - pattern=r"[\w\[;\]]+(,[\w\[;\]]+)*", + "characters. If there are multiple matches, they are concatenated with ' ... '.", + pattern=r"\w+(,\w+)*", ), highlight: bool = Query(False, description="If true, highlight fields"), per_page: int = Query(None, description="Number of results per page"), @@ -89,12 +134,9 @@ def get_documents( """ indices = index.split(",") fields = fields and fields.split(",") - snippets = snippets and snippets.split(",") - if not fields and not snippets: - fields = ["date", "title", "url"] - + fields = get_or_validate_allowed_fields(user, indices, fields) for index in indices: - check_query_allowed(index, user, fields, snippets) + check_fields_access(index, user, fields) args = {} sort = sort and [ @@ -185,18 +227,13 @@ def query_documents_post( "or a dict of {'label': 'query'}", ), fields: Optional[List[str]] = Body( - None, description="List of fields to retrieve for each document" - ), - snippets: Optional[List[str]] = Body( None, - description="Fields to retrieve as snippets. If only field names are given, the default " - "snippet parameters are used. The parameters are [nomatch_chars] (default: 200), [max_matches] (default: 3) " - "and [match_chars] (default: 50). If there is no query, the snippet is the first [nomatch_chars] characters. " + description="List of fields to retrieve for each document" + "You can also request a snippet of a field by adding snippet parameters between brackets: " + "fieldname[nomatch_chars;max_matches;match_chars]. 'matches' here refers to words from text queries. " + "If there is no query, the snippet is the first [nomatch_chars] characters. " "If there is a query, snippets are returned for up to [max_matches] matches, with each match having [match_chars] " - "characters. match snippets also have tags around the matched text. " - "If you want to use custom snippet parameters, you can add a suffix to the field name with the parameters between " - "brackets, in the format: fieldname[nomatch_chars;max_matches;match_chars] (e.g, text[150;3;50]). " - "(always provide all 3 parameters, even if you only want to change one)", + "characters. If there are multiple matches, they are concatenated with ' ... 
'.", ), filters: Optional[ Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] @@ -244,17 +281,8 @@ def query_documents_post( # TODO check user rights on index # Standardize fields, queries and filters to their most versatile format indices = index.split(",") - if fields: - if isinstance(fields, str): - fields = [fields] - if snippets: - if isinstance(snippets, str): - snippets = [snippets] - if not fields and not snippets: - fields = ["date", "title", "url"] - - for index in indices: - check_query_allowed(index, user, fields, snippets) + fields = [fields] if isinstance(fields, str) else fields + fields = get_or_validate_allowed_fields(user, indices, fields) queries = _process_queries(queries) filters = dict(_process_filters(filters)) @@ -263,7 +291,6 @@ def query_documents_post( queries=queries, filters=filters, fields=fields, - snippets=snippets, sort=sort, per_page=per_page, page=page, diff --git a/amcat4/elastic.py b/amcat4/elastic.py index e81668d..f6e7d9b 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -18,9 +18,9 @@ from elasticsearch import Elasticsearch, NotFoundError from elasticsearch.helpers import bulk +from amcat4.util import parse_field from amcat4.config import get_settings -from amcat4.util import parse_snippet SYSTEM_INDEX_VERSION = 1 @@ -219,16 +219,28 @@ def es_actions(index, documents): bulk(es(), actions) -def get_field_mapping(type_: Union[str, dict]): - if isinstance(type_, str): - return ES_MAPPINGS[type_] +def get_field_mapping(current: dict, update: Union[str, dict]): + if isinstance(update, str): + type = update + newmeta = None else: - mapping = ES_MAPPINGS[type_["type"]] - meta = mapping.get("meta", {}) - if m := type_.get("meta"): - meta.update(m) - mapping["meta"] = validate_field_meta(meta) - return mapping + if "type" not in update: + type = current.get("type") + if type is None: + raise ValueError("Field type is not specified") + type = update.get("type") + if type not in ES_MAPPINGS: + raise ValueError(f"Invalid field type: {type}") + newmeta = update.get("meta", None) + + if "meta" in current: + meta = current["meta"] + else: + meta = ES_MAPPINGS[type].get("meta", {}) + if newmeta: + meta.update(newmeta) + + return dict(type=type, meta=meta) def validate_field_meta(meta: dict): @@ -258,21 +270,25 @@ def validate_field_meta(meta: dict): # metareader_access can be "none", "read", or "snippet" # if snippet, can also include the maximum snippet parameters (nomatch_chars, max_matches, match_chars) # in the format: snippet[nomatch_chars;max_matches;match_chars] - reg = r"^(read|none|snippet(\[\d+;\d+;\d+\])?)$" + reg = r"^(read|none|snippet\[\d+;\d+;\d+\])$" if not re.match(reg, meta[meta_field]): raise ValueError(f"Invalid metareader_access value: {meta[meta_field]}") return meta -def set_fields(index: str, fields: Mapping[str, str]): +def set_fields(index: str, fields: Mapping[str, Union[str, dict]]): """ Update the column types for this index :param index: The name of the index (without prefix) :param fields: A mapping of field:type for column types """ - properties = {field: get_field_mapping(type_) for (field, type_) in fields.items()} + index_fields = get_index_fields(index) + properties = { + field: get_field_mapping(index_fields.get(field, {}), update) + for (field, update) in fields.items() + } es().indices.put_mapping(index=index, properties=properties) @@ -322,6 +338,7 @@ def _get_type_from_property(properties: dict) -> str: def _get_fields(index: str) -> Iterable[Tuple[str, dict]]: r = es().indices.get_mapping(index=index) + for 
k, v in r[index]["mappings"]["properties"].items(): t = dict(name=k, type=_get_type_from_property(v)) if meta := v.get("meta"): @@ -338,65 +355,16 @@ def get_index_fields(index: str) -> Mapping[str, dict]: return dict(_get_fields(index)) -def get_fields(index: Union[str, Sequence[str]]) -> Mapping[str, dict]: +def get_fields(index: str) -> Mapping[str, dict]: """ Get the field types in use in this index or indices :param index: name(s) of index(es) to query :return: a dict of fieldname: field objects {fieldname: {name, type, ...}] """ - if isinstance(index, str): - return get_index_fields(index) - - # def get_meta_value(field, meta_key, default): - # return field.get("meta", {}).get(meta_key) or default - - # def get_least_metareader_access(access1, access2): - # if (access1 == None) or (access2 == None): - # return None - - # if "snippet" in access1 and access2 == "read": - # return access1 - - # if "snippet" in access2 and access1 == "read": - # return access2 - - # if "snippet" in access1 and "snippet" in access2: - # _, nomatch_chars1, max_matches1, match_chars1 = parse_snippet(access1) - # _, nomatch_chars2, max_matches2, match_chars2 = parse_snippet(access2) - # nomatch_chars = min(nomatch_chars1, nomatch_chars2) - # max_matches = min(max_matches1, max_matches2) - # match_chars = match_chars1 + match_chars2 - # return f"snippet[{nomatch_chars},{max_matches},{match_chars}]" - - # if access1 == "read" and access2 == "read": - # return "read" - - result = {} - for ix in index: - for f, ftype in get_index_fields(ix).items(): - if f in result: - if result[f] != ftype: - # note that for merged fields metareader access is always None - # metareader_access_1: bool = get_meta_value( - # result[f], "metareader_visible", None - # ) - # metareader_access_2: bool = get_meta_value( - # ftype, "metareader_visible", None - # ) - # metareader_access = get_least_metareader_access( - # metareader_access_1, metareader_access_2 - # ) - - result[f] = { - "name": f, - "type": "keyword", - "meta": { - "merged": True, - }, - } - else: - result[f] = ftype - return result + if not isinstance(index, str): + raise ValueError("get_fields only supports a single index") + + return get_index_fields(index) def get_field_values(index: str, field: str, size: int) -> List[str]: diff --git a/amcat4/index.py b/amcat4/index.py index 440b8fa..cff55ce 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -38,6 +38,7 @@ from amcat4.config import get_settings from amcat4.elastic import DEFAULT_MAPPING, es, get_fields +from amcat4.models import FieldSettings, updateFieldSettings class Role(IntEnum): @@ -51,7 +52,8 @@ class Role(IntEnum): GLOBAL_ROLES = "_global" Index = collections.namedtuple( - "Index", ["id", "name", "description", "guest_role", "roles", "summary_field"] + "Index", + ["id", "name", "description", "guest_role", "roles", "summary_field"], ) @@ -84,14 +86,8 @@ def list_known_indices(email: str = None) -> Iterable[Index]: # "must_not": {"term": {"guest_role": {"value": "none", "case_insensitive": True}}}}} # q_role = {"nested": {"path": "roles", "query": {"term": {"roles.email": email}}}} # query = {"bool": {"should": [q_guest, q_role]}} - check_role = not ( - email is None - or get_global_role(email) == Role.ADMIN - or get_settings().auth == "no_auth" - ) - for index in elasticsearch.helpers.scan( - es(), index=get_settings().system_index, fields=[], _source=True - ): + check_role = not (email is None or get_global_role(email) == Role.ADMIN or get_settings().auth == "no_auth") + for index in 
elasticsearch.helpers.scan(es(), index=get_settings().system_index, fields=[], _source=True): ix = _index_from_elastic(index) if ix.name == GLOBAL_ROLES: continue @@ -132,7 +128,11 @@ def create_index( """ es().indices.create(index=index, mappings={"properties": DEFAULT_MAPPING}) register_index( - index, guest_role=guest_role, name=name, description=description, admin=admin + index, + guest_role=guest_role, + name=name, + description=description, + admin=admin, ) @@ -210,7 +210,7 @@ def set_role(index: str, email: str, role: Optional[Role]): try: d = es().get(index=system_index, id=index, source_includes="roles") except NotFoundError: - raise ValueError(f"Index {index} does is not registered") + raise ValueError(f"Index {index} is not registered") roles_dict = _roles_from_elastic(d["_source"].get("roles", [])) if role: roles_dict[email] = role @@ -219,7 +219,9 @@ def set_role(index: str, email: str, role: Optional[Role]): return # Nothing to change del roles_dict[email] es().update( - index=system_index, id=index, doc=dict(roles=_roles_to_elastic(roles_dict)) + index=system_index, + id=index, + doc=dict(roles=_roles_to_elastic(roles_dict)), ) @@ -237,6 +239,38 @@ def set_guest_role(index: str, guest_role: Optional[Role]): modify_index(index, guest_role=guest_role, remove_guest_role=(guest_role is None)) +def _fields_settings_to_elastic(fields_settings: Dict[str, FieldSettings]) -> List[Dict]: + return [{"field": field, "settings": settings} for field, settings in fields_settings.items()] + + +def _fields_settings_from_elastic( + fields_settings: List[Dict], +) -> Dict[str, FieldSettings]: + return {fs["field"]: fs["settings"] for fs in fields_settings} + + +def set_fields_settings(index: str, new_fields_settings: Dict[str, FieldSettings]): + """ + Set the fields settings for this index + """ + system_index = get_settings().system_index + try: + d = es().get(index=system_index, id=index, source_includes="fields_settings") + except NotFoundError: + raise ValueError(f"Index {index} is not registered") + fields_settings = _fields_settings_from_elastic(d["_source"].get("fields_settings", {})) + + for field, new_settings in new_fields_settings.items(): + current: FieldSettings = fields_settings.get(field, FieldSettings()) + fields_settings[field] = updateFieldSettings(current, new_settings) + + es().update( + index=system_index, + id=index, + doc=dict(roles=_fields_settings_to_elastic(fields_settings)), + ) + + def modify_index( index: str, name: Optional[str] = None, @@ -256,9 +290,7 @@ def modify_index( if summary_field not in f: raise ValueError(f"Summary field {summary_field} does not exist!") if f[summary_field]["type"] not in ["date", "keyword", "tag"]: - raise ValueError( - f"Summary field {summary_field} should be date, keyword or tag, not {f[summary_field]['type']}!" 
- ) + raise ValueError(f"Summary field {summary_field} should be date, keyword or tag, not {f[summary_field]['type']}!") doc = {x: v for (x, v) in doc.items() if v} if remove_guest_role: doc["guest_role"] = None @@ -302,6 +334,7 @@ def get_role(index: str, email: str) -> Optional[Role]: role = doc["_source"].get("guest_role", None) if role and role.lower() != "none": return Role[role] + return None def get_guest_role(index: str) -> Optional[Role]: @@ -311,13 +344,16 @@ def get_guest_role(index: str) -> Optional[Role]: """ try: d = es().get( - index=get_settings().system_index, id=index, source_includes="guest_role" + index=get_settings().system_index, + id=index, + source_includes="guest_role", ) except NotFoundError: raise IndexDoesNotExist(index) role = d["_source"].get("guest_role") if role and role.lower() != "none": return Role[role] + return None def get_global_role(email: str, only_es: bool = False) -> Optional[Role]: @@ -333,6 +369,21 @@ def get_global_role(email: str, only_es: bool = False) -> Optional[Role]: return get_role(index=GLOBAL_ROLES, email=email) +def get_fields_settings(index: str) -> Dict[str, FieldSettings]: + """ + Retrieve the fields settings for this index + """ + try: + d = es().get( + index=get_settings().system_index, + id=index, + source_includes="fields_settings", + ) + except NotFoundError: + raise IndexDoesNotExist(index) + return _fields_settings_from_elastic(d["_source"].get("fields_settings", {})) + + def list_users(index: str) -> Dict[str, Role]: """ " List all users and their roles on the given index diff --git a/amcat4/models.py b/amcat4/models.py new file mode 100644 index 0000000..aed6c68 --- /dev/null +++ b/amcat4/models.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel +from typing import Optional + + +class SnippetParams(BaseModel): + """ + Snippet parameters for a specific field. + nomatch_chars is the number of characters to show if there is no query match. This is always + the first [nomatch_chars] of the field. + """ + + nomatch_chars: int + max_matches: int + match_chars: int + + +class FieldMetareaderAccess(BaseModel): + """Metareader access for a specific field.""" + + access: bool + max_snippet: Optional[SnippetParams] + + +class FieldSettings(BaseModel): + """Settings for a field.""" + + metareader_access: Optional[FieldMetareaderAccess] = None + + +def updateFieldSettings(field: FieldSettings, update: FieldSettings): + for key in field.model_fields_set: + setattr(field, key, getattr(update, key)) + return field diff --git a/amcat4/query.py b/amcat4/query.py index 2de0ca1..71e8e33 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -18,7 +18,9 @@ from .date_mappings import mappings from .elastic import es, update_tag_by_query -from amcat4.util import parse_snippet +from amcat4 import elastic +from amcat4.util import parse_field +from amcat4.index import Role, get_role def build_body( @@ -164,7 +166,6 @@ def query_documents( specify the time the context should be kept alive, or True to get the default of 2m. 
     :param scroll_id: if not None, should be a previously returned context_id to retrieve a new page of results
     :param fields: if not None, specify a list of fields to retrieve for each hit
-    :param snippets: if not None, specify a list of fields to retrieve snippets for
     :param filters: if not None, a dict of filters with either value, values,
                     or gte/gt/lte/lt ranges: {field: {'values': [value1,value2],
                                                       'value': value,
@@ -179,10 +180,6 @@ def query_documents(
     """
     if fields is not None and not isinstance(fields, list):
         raise ValueError("fields should be a list")
-    if snippets is not None and not isinstance(snippets, list):
-        raise ValueError("snippets should be a list")
-    if overlap_fields_snippets(fields, snippets):
-        raise ValueError("Cannot request a field AND it's snippet at the same time")
 
     if scroll or scroll_id:
         # set scroll to default also if scroll_id is given but no scroll time is known
@@ -195,7 +192,7 @@ def query_documents(
         if not result["hits"]["hits"]:
             return None
     else:
-        h = query_highlight(fields, highlight, snippets)
+        h = query_highlight(fields, highlight)
         body = build_body(queries.values(), filters, h)
 
         if fields:
@@ -231,35 +228,44 @@ def query_documents(
     )
 
 
-def query_highlight(fields: Iterable[str], highlight: bool, snippets: Iterable[str]):
+def query_highlight(fields: Iterable[str] = None, highlight_queries: bool = False):
     """
     The elastic "highlight" parameters works for both highlighting text fields and adding snippets.
     This function will return the highlight parameter to be added to the query body.
     """
-    if highlight is False and snippets is None:
-        return None
 
     highlight = {
+        # "pre_tags": ["<em>"] if highlight is True else [""],
+        # "post_tags": ["</em>"] if highlight is True else [""],
         "require_field_match": True,
-        "fields": {},
     }
 
-    if fields is not None:
+    if fields is None:
+        if highlight_queries is True:
+            highlight["fields"] = {"*": {"number_of_fragments": 0}}
+    else:
+        highlight["fields"] = {}
         for field in fields:
-            highlight["fields"][field] = {"number_of_fragments": 0}
-
-    if snippets is not None:
-        # TODO: get index meta data to see which snippets are allowed and what
-        # the nr and size should be
-        for snippet in snippets:
-            field, nomatch_chars, max_matches, match_chars = parse_snippet(snippet)
-            highlight["fields"][field] = {
-                "no_match_size": nomatch_chars,
-                "number_of_fragments": max_matches,
-                "fragment_size": match_chars,
-            }
+            fieldname, nomatch_chars, max_matches, match_chars = parse_field(field)
+            if nomatch_chars is None:
+                if highlight_queries is True:
+                    # This will overwrite the field with the highlighted version, so
+                    # only needed if highlight is True
+                    highlight["fields"][fieldname] = {"number_of_fragments": 0}
+            else:
+                # the elastic highlight feature is also used to get snippets
+                highlight["fields"][fieldname] = {
+                    "no_match_size": nomatch_chars,
+                    "number_of_fragments": max_matches,
+                    "fragment_size": match_chars,
+                }
+            if highlight_queries is False or max_matches == 0:
+                # This overwrites the actual query, so that the highlights are not returned.
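+                # (a match_all highlight_query has no terms to highlight, so elastic
+                # falls back to the plain no_match_size snippet)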
+ # Also used to get the nomatch snippet if max_matches = 0 + highlight["fields"][fieldname]["highlight_query"] = { + "match_all": {} + } return highlight @@ -289,20 +295,3 @@ def update_tag_query( """Add or remove tags using a query""" body = build_body(queries and queries.values(), filters, ids=ids) update_tag_by_query(index, action, body, field, tag) - - -def overlap_fields_snippets( - fields: Iterable[str] = None, snippets: Iterable[str] = None -) -> bool: - """ - If both fields and snippets are requested as output, check if there are any overlaps - """ - if fields is None or snippets is None: - return False - - for snippet in snippets: - field, _, _, _ = parse_snippet(snippet) - if field in fields: - return True - - return False diff --git a/amcat4/util.py b/amcat4/util.py index 39ce16d..fbefe23 100644 --- a/amcat4/util.py +++ b/amcat4/util.py @@ -2,25 +2,24 @@ from typing import Tuple -def parse_snippet(snippet: str) -> Tuple[str, int, int, int]: +def parse_field(field: str) -> Tuple[str, int, int, int]: """ - Parse a snippet string into a field and the snippet parameters. + Parse a field into a field and the snippet parameters. The format is fieldname[nomatch_chars;max_matches;match_chars]. - If the snippet does not contain parameters (or the specification is wrong), - we assume the snippet is just the field name and use default values. + If no snippet parameters are given, the values are None """ pattern = r"\[([0-9]+);([0-9]+);([0-9]+)]$" - match = re.search(pattern, snippet) + match = re.search(pattern, field) if match: - field = snippet[: match.start()] + fieldname = field[: match.start()] nomatch_chars = int(match.group(1)) max_matches = int(match.group(2)) match_chars = int(match.group(3)) else: - field = snippet - nomatch_chars = 200 - max_matches = 3 - match_chars = 50 + fieldname = field + nomatch_chars = None + max_matches = None + match_chars = None - return field, nomatch_chars, max_matches, match_chars + return fieldname, nomatch_chars, max_matches, match_chars diff --git a/tests/test_api_metareader.py b/tests/test_api_metareader.py index 47927f6..54f1e18 100644 --- a/tests/test_api_metareader.py +++ b/tests/test_api_metareader.py @@ -25,16 +25,13 @@ def set_metareader_access(client, index, admin, access): ) -def check_allowed(client, index, field=None, snippet=None, allowed=True): +def check_allowed(client, index, field=None, allowed=True): params = {} body = {} if field: params["fields"] = field body["fields"] = [field] - if snippet: - params["snippets"] = snippet - body["snippets"] = [snippet] get_json( client, @@ -60,7 +57,7 @@ def test_metareader_none(client: TestClient, admin, index_docs): create_index_metareader(client, index_docs, admin) set_metareader_access(client, index_docs, admin, "none") check_allowed(client, index_docs, field="text", allowed=False) - check_allowed(client, index_docs, snippet="text", allowed=False) + check_allowed(client, index_docs, field="text[150;3;50]", allowed=False) def test_metareader_read(client: TestClient, admin, index_docs): @@ -71,22 +68,10 @@ def test_metareader_read(client: TestClient, admin, index_docs): create_index_metareader(client, index_docs, admin) set_metareader_access(client, index_docs, admin, "read") check_allowed(client, index_docs, field="text", allowed=True) - check_allowed(client, index_docs, snippet="text", allowed=True) + check_allowed(client, index_docs, field="text[150;3;50]", allowed=True) def test_metareader_snippet(client: TestClient, admin, index_docs): - """ - Set text field to 
metareader_access=snippet
-    Meta reader should be able to get field as snippet, but not full
-    """
-    create_index_metareader(client, index_docs, admin)
-
-    set_metareader_access(client, index_docs, admin, "snippet")
-    check_allowed(client, index_docs, field="text", allowed=False)
-    check_allowed(client, index_docs, snippet="text", allowed=True)
-
-
-def test_metareader_snippet_params(client: TestClient, admin, index_docs):
     """
     Set text field to metareader_access=snippet[50;1;20]
     Metareader should only be able to get field as snippet
     with maximum parameters of nomatch_chars=50, max_matches=1, match_chars=20
     """
     create_index_metareader(client, index_docs, admin)
 
     set_metareader_access(client, index_docs, admin, "snippet[50;1;20]")
     check_allowed(client, index_docs, field="text", allowed=False)
-    check_allowed(client, index_docs, snippet="text", allowed=False)
-    check_allowed(client, index_docs, snippet="text[51;1;20]", allowed=False)
-    check_allowed(client, index_docs, snippet="text[50;1;20]", allowed=True)
-    check_allowed(client, index_docs, snippet="text[49;1;20]", allowed=True)
+    check_allowed(client, index_docs, field="text[51;1;20]", allowed=False)
+    check_allowed(client, index_docs, field="text[50;1;20]", allowed=True)
+    check_allowed(client, index_docs, field="text[49;1;20]", allowed=True)
diff --git a/tests/test_api_query.py b/tests/test_api_query.py
index eb51992..e6fae6c 100644
--- a/tests/test_api_query.py
+++ b/tests/test_api_query.py
@@ -3,46 +3,16 @@
 from tests.conftest import upload
 from tests.tools import get_json, post_json, dictset
 
-TEST_DOCUMENTS = [
-    {
-        "cat": "a",
-        "subcat": "x",
-        "i": 1,
-        "date": "2018-01-01",
-        "text": "this is a text",
-    },
-    {
-        "cat": "a",
-        "subcat": "x",
-        "i": 2,
-        "date": "2018-02-01",
-        "text": "a test text",
-    },
-    {
-        "cat": "a",
-        "subcat": "y",
-        "i": 11,
-        "date": "2020-01-01",
-        "text": "and this is another test toto",
-        "title": "bla",
-    },
-    {
-        "cat": "b",
-        "subcat": "y",
-        "i": 31,
-        "date": "2018-01-01",
-        "text": "Toto je testovací článek",
-        "title": "more bla",
-    },
-]
-
 
 def test_query_get(client, index_docs, user):
     """Can we run a simple query?"""
 
     def q(**query_string):
         return get_json(
-            client, f"/index/{index_docs}/documents", user=user, params=query_string
+            client,
+            f"/index/{index_docs}/documents",
+            user=user,
+            params=query_string,
         )["results"]
 
     def qi(**query_string):
@@ -63,8 +33,8 @@ def qi(**query_string):
     assert qi(date__gte="2018-02-01", date__lt="2020-01-01") == {1}
 
     # Can we request specific fields?
-    default_fields = {"_id", "date", "title"}
-    assert set(q()[0].keys()) == default_fields
+    all_fields = {"_id", "cat", "subcat", "i", "date", "text", "title"}
+    assert set(q()[0].keys()) == all_fields
     assert set(q(fields="cat")[0].keys()) == {"_id", "cat"}
     assert set(q(fields="date,title")[0].keys()) == {"_id", "date", "title"}
@@ -94,8 +64,8 @@ def qi(**query_string):
     assert qi(filters={"cat": {"values": ["a"]}}) == {0, 1, 2}
 
     # Can we request specific fields?
- default_fields = {"_id", "date", "title"} - assert set(q()[0].keys()) == default_fields + all_fields = {"_id", "cat", "subcat", "i", "date", "text", "title"} + assert set(q()[0].keys()) == all_fields assert set(q(fields=["cat"])[0].keys()) == {"_id", "cat"} assert set(q(fields=["date", "title"])[0].keys()) == {"_id", "date", "title"} @@ -166,21 +136,34 @@ def test_multiple_index(client, index_docs, index, user): ) indices = f"{index},{index_docs}" assert ( - len(get_json(client, f"/index/{indices}/documents", user=user)["results"]) == 5 + len( + get_json( + client, + f"/index/{indices}/documents", + user=user, + params=dict(fields="_id"), + )["results"] + ) + == 5 ) assert ( len( - post_json(client, f"/index/{indices}/query", user=user, expected=200)[ - "results" - ] + post_json( + client, + f"/index/{indices}/query", + user=user, + expected=200, + json=dict(fields=["_id"]), + )["results"] ) == 5 ) + r = post_json( client, f"/index/{indices}/aggregate", user=user, - json={"axes": [{"field": "cat"}]}, + json={"axes": [{"field": "cat"}], "fields": ["_id"]}, expected=200, ) assert dictset(r["data"]) == dictset( From b6294a1236415eaff9b17e0042669234a712b63a Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Thu, 11 Jan 2024 18:17:31 +0100 Subject: [PATCH 08/80] typing hell --- .gitignore | 3 +- amcat4/__main__.py | 82 +++-------- amcat4/aggregate.py | 116 ++++++++++----- amcat4/api/index.py | 156 +++++++++----------- amcat4/elastic.py | 333 +----------------------------------------- amcat4/index.py | 270 ++++++++++++++++++++++++++++++---- amcat4/models.py | 20 ++- tests/test_elastic.py | 80 +++++----- 8 files changed, 473 insertions(+), 587 deletions(-) diff --git a/.gitignore b/.gitignore index 0c9bf0f..7440fed 100644 --- a/.gitignore +++ b/.gitignore @@ -49,8 +49,7 @@ nosetests.xml # PyCharm meuk .idea -# vscode meuk -#.vscode +# vscode meuk (only include extensions and settings) .vscode/* !.vscode/settings.json !.vscode/extensions.json diff --git a/amcat4/__main__.py b/amcat4/__main__.py index 6873810..bfd3b56 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -19,14 +19,8 @@ from amcat4 import index from amcat4.config import get_settings, AuthOptions, validate_settings -from amcat4.elastic import connect_elastic, get_system_version, ping, upload_documents -from amcat4.index import ( - GLOBAL_ROLES, - create_index, - set_global_role, - Role, - list_global_users, -) +from amcat4.elastic import connect_elastic, get_system_version, ping +from amcat4.index import GLOBAL_ROLES, create_index, set_global_role, Role, list_global_users, upload_documents SOTU_INDEX = "state_of_the_union" @@ -58,13 +52,10 @@ def upload_test_data() -> str: def run(args): auth = get_settings().auth - logging.info( - f"Starting server at port {args.port}, debug={not args.nodebug}, auth={auth}" - ) + logging.info(f"Starting server at port {args.port}, debug={not args.nodebug}, auth={auth}") if auth == AuthOptions.no_auth: logging.warning( - "Warning: No authentication is set up - " - "everyone who can access this service can view and change all data" + "Warning: No authentication is set up - " "everyone who can access this service can view and change all data" ) if validate_settings(): logging.warning(validate_settings()) @@ -75,9 +66,7 @@ def run(args): ) if ping(): logging.info(f"Connect to elasticsearch {get_settings().elastic_host}") - uvicorn.run( - "amcat4.api:app", host="0.0.0.0", reload=not args.nodebug, port=args.port - ) + uvicorn.run("amcat4.api:app", host="0.0.0.0", reload=not args.nodebug, 
port=args.port) def val(val_or_list): @@ -95,24 +84,18 @@ def migrate_index(_args): logging.error(f"Cannot connect to elasticsearch server {settings.elastic_host}") sys.exit(1) if not elastic.indices.exists(index=settings.system_index): - logging.info( - "System index does not exist yet. It will be created automatically if you run the server" - ) + logging.info("System index does not exist yet. It will be created automatically if you run the server") sys.exit(1) # Check index format version version = get_system_version(elastic) - logging.info( - f"{settings.elastic_host}::{settings.system_index} is at version {version or 0}" - ) + logging.info(f"{settings.elastic_host}::{settings.system_index} is at version {version or 0}") if version == 1: logging.info("Nothing to do") else: logging.info("Migrating to version 1") fields = ["index", "email", "role"] indices = collections.defaultdict(dict) - for entry in elasticsearch.helpers.scan( - elastic, index=settings.system_index, fields=fields, _source=False - ): + for entry in elasticsearch.helpers.scan(elastic, index=settings.system_index, fields=fields, _source=False): index, email, role = [val(entry["fields"][field]) for field in fields] indices[index][email] = role if GLOBAL_ROLES not in indices: @@ -122,10 +105,7 @@ def migrate_index(_args): for index, roles_dict in indices.items(): guest_role = roles_dict.pop("_guest", None) roles_dict.pop("admin", None) - roles = [ - {"email": email, "role": role} - for (email, role) in roles_dict.items() - ] + roles = [{"email": email, "role": role} for (email, role) in roles_dict.items()] doc = dict(name=index, guest_role=guest_role, roles=roles) if index == GLOBAL_ROLES: doc["version"] = 1 @@ -191,9 +171,7 @@ def list_users(_args): for user, role in users: print(f"{role.name:10}: {user}") if not (users or admin_password): - print( - "(No users defined yet, set AMCAT4_ADMIN_PASSWORD in environment use add-admin to add users by email)" - ) + print("(No users defined yet, set AMCAT4_ADMIN_PASSWORD in environment use add-admin to add users by email)") def config_amcat(args): @@ -208,9 +186,7 @@ def config_amcat(args): fieldinfo = settings.model_fields[fieldname] validation_function = AuthOptions.validate if fieldname == "auth" else None value = getattr(settings, fieldname) - value = menu( - fieldname, fieldinfo, value, validation_function=validation_function - ) + value = menu(fieldname, fieldinfo, value, validation_function=validation_function) if value is ABORTED: return if value is not UNCHANGED: @@ -259,9 +235,7 @@ def menu(fieldname: str, fieldinfo: FieldInfo, value, validation_function=None): print(f"The current value for {bold(fieldname)} is {bold(value)}.") while True: try: - value = input( - "Enter a new value, press [enter] to leave unchanged, or press [control+c] to abort: " - ) + value = input("Enter a new value, press [enter] to leave unchanged, or press [control+c] to abort: ") except KeyboardInterrupt: return ABORTED if not value.strip(): @@ -275,9 +249,7 @@ def menu(fieldname: str, fieldinfo: FieldInfo, value, validation_function=None): def main(): parser = argparse.ArgumentParser(description=__doc__, prog="python -m amcat4") - subparsers = parser.add_subparsers( - dest="action", title="action", help="Action to perform:", required=True - ) + subparsers = parser.add_subparsers(dest="action", title="action", help="Action to perform:", required=True) p = subparsers.add_parser("run", help="Run the backend API in development mode") p.add_argument( "--no-debug", @@ -288,22 +260,14 @@ def main(): 
p.add_argument("-p", "--port", help="Port", default=5000) p.set_defaults(func=run) - p = subparsers.add_parser( - "create-env", help="Create the .env file with a random secret key" - ) + p = subparsers.add_parser("create-env", help="Create the .env file with a random secret key") p.add_argument("-a", "--admin_email", help="The email address of the admin user.") - p.add_argument( - "-p", "--admin_password", help="The password of the built-in admin user." - ) - p.add_argument( - "-P", "--no-admin_password", action="store_true", help="Disable admin password" - ) + p.add_argument("-p", "--admin_password", help="The password of the built-in admin user.") + p.add_argument("-P", "--no-admin_password", action="store_true", help="Disable admin password") p.set_defaults(func=create_env) - p = subparsers.add_parser( - "config", help="Configure amcat4 settings in an interactive menu." - ) + p = subparsers.add_parser("config", help="Configure amcat4 settings in an interactive menu.") p.set_defaults(func=config_amcat) p = subparsers.add_parser("add-admin", help="Add a global admin") @@ -313,21 +277,15 @@ def main(): p = subparsers.add_parser("list-users", help="List global users") p.set_defaults(func=list_users) - p = subparsers.add_parser( - "create-test-index", help=f"Create the {SOTU_INDEX} test index" - ) + p = subparsers.add_parser("create-test-index", help=f"Create the {SOTU_INDEX} test index") p.set_defaults(func=create_test_index) - p = subparsers.add_parser( - "migrate", help="Migrate the system index to the current version" - ) + p = subparsers.add_parser("migrate", help="Migrate the system index to the current version") p.set_defaults(func=migrate_index) args = parser.parse_args() - logging.basicConfig( - format="[%(levelname)-7s:%(name)-15s] %(message)s", level=logging.INFO - ) + logging.basicConfig(format="[%(levelname)-7s:%(name)-15s] %(message)s", level=logging.INFO) es_logger = logging.getLogger("elasticsearch") es_logger.setLevel(logging.WARNING) diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index 4845f2c..1a81653 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -5,8 +5,10 @@ from typing import Mapping, Iterable, Union, Tuple, Sequence, List, Dict, Optional from amcat4.date_mappings import interval_mapping -from amcat4.elastic import es, get_fields +from amcat4.elastic import es +from amcat4.index import get_fields from amcat4.query import build_body, _normalize_queries +from amcat4.models import Field def _combine_mappings(mappings): @@ -21,7 +23,8 @@ class Axis: """ Class that specifies an aggregation axis """ - def __init__(self, field: str, interval: str = None, name: str = None, field_type: str = None): + + def __init__(self, field: str, interval: str | None = None, name: str | None = None, field_type: str | None = None): self.field = field self.interval = interval self.ftype = field_type @@ -53,7 +56,7 @@ def get_value(self, values): if m := interval_mapping(self.interval): value = m.postprocess(value) elif self.ftype == "date": - value = datetime.utcfromtimestamp(value / 1000.) 
+            value = datetime.utcfromtimestamp(value / 1000.0)
             if self.interval in {"year", "month", "week", "day"}:
                 value = value.date()
         return value
@@ -70,6 +73,7 @@ class Aggregation:
     """
     Specification of a single aggregation, that is, field and aggregation function
     """
+
     def __init__(self, field: str, function: str, name: str = None, ftype: str = None):
         self.field = field
         self.function = function
@@ -80,9 +84,9 @@ def dsl_item(self):
         return self.name, {self.function: {"field": self.field}}

     def get_value(self, bucket: dict):
-        result = bucket[self.name]['value']
+        result = bucket[self.name]["value"]
         if result and self.ftype == "date":
-            result = datetime.utcfromtimestamp(result / 1000.)
+            result = datetime.utcfromtimestamp(result / 1000.0)
         return result

     def asdict(self):
@@ -95,8 +99,7 @@ def aggregation_dsl(aggregations: Iterable[Aggregation]) -> dict:


 class AggregateResult:
-    def __init__(self, axes: Sequence[Axis], aggregations: List[Aggregation],
-                 data: List[tuple], count_column: str = "n"):
+    def __init__(self, axes: Sequence[Axis], aggregations: List[Aggregation], data: List[tuple], count_column: str = "n"):
         self.axes = axes
         self.data = data
         self.aggregations = aggregations
@@ -104,26 +107,34 @@ def __init__(self, axes: Sequence[Axis], aggregations: List[Aggregation],

     def as_dicts(self) -> Iterable[dict]:
         """Return the results as a sequence of {axis1, ..., n} dicts"""
-        keys = tuple(ax.name for ax in self.axes) + (self.count_column, )
+        keys = tuple(ax.name for ax in self.axes) + (self.count_column,)
         if self.aggregations:
             keys += tuple(a.name for a in self.aggregations)
         for row in self.data:
             yield dict(zip(keys, row))


-def _bare_aggregate(index: str, queries, filters, aggregations: Sequence[Aggregation]) -> Tuple[int, dict]:
+def _bare_aggregate(index: str | list[str], queries, filters, aggregations: Sequence[Aggregation]) -> Tuple[int, dict]:
     """
     Aggregate without sources/group_by. Returns a tuple of doc count and aggregations (doc_count, {metric: value})
     """
     body = build_body(queries=queries, filters=filters) if filters or queries else {}
+    index = index if isinstance(index, str) else ",".join(index)
     aresult = es().search(index=index, size=0, aggregations=aggregation_dsl(aggregations), **body)
     cresult = es().count(index=index, **body)
-    return cresult['count'], aresult['aggregations']
+    return cresult["count"], aresult["aggregations"]


-def _elastic_aggregate(index: Union[str, List[str]], sources, queries, filters, aggregations: Sequence[Aggregation],
-                       runtime_mappings: Mapping[str, Mapping] = None, after_key=None) -> Iterable[dict]:
+def _elastic_aggregate(
+    index: str | list[str],
+    sources,
+    queries,
+    filters,
+    aggregations: list[Aggregation],
+    runtime_mappings: dict[str, Mapping] | None = None,
+    after_key=None,
+) -> Iterable[dict]:
     """
     Recursively get all buckets from a composite query.
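     Elasticsearch returns composite buckets one page at a time; each response can
     include an 'after_key', which is fed back into the next request to resume from
     where the previous page ended. Illustrative request body (a sketch, not the
     literal DSL built below):
         {"aggs": {"composite": {"sources": [{"cat": {"terms": {"field": "cat"}}}],
                                 "after": {"cat": "sports"}}}}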
Yields 'buckets' consisting of {key: {axis: value}, doc_count: } @@ -133,38 +144,49 @@ def _elastic_aggregate(index: Union[str, List[str]], sources, queries, filters, after = {"after": after_key} if after_key else {} aggr: Dict[str, Dict[str, dict]] = {"aggs": {"composite": dict(sources=sources, **after)}} if aggregations: - aggr["aggs"]['aggregations'] = aggregation_dsl(aggregations) + aggr["aggs"]["aggregations"] = aggregation_dsl(aggregations) kargs = {} if filters or queries: q = build_body(queries=queries.values(), filters=filters) kargs["query"] = q["query"] - result = es().search(index=index if isinstance(index, str) else ",".join(index), - size=0, aggregations=aggr, runtime_mappings=runtime_mappings, **kargs - ) + result = es().search( + index=index if isinstance(index, str) else ",".join(index), + size=0, + aggregations=aggr, + runtime_mappings=runtime_mappings, + **kargs, + ) if failure := result.get("_shards", {}).get("failures"): - raise Exception(f'Error on running aggregate search: {failure}') - yield from result['aggregations']['aggs']['buckets'] - after_key = result['aggregations']['aggs'].get('after_key') + raise Exception(f"Error on running aggregate search: {failure}") + yield from result["aggregations"]["aggs"]["buckets"] + after_key = result["aggregations"]["aggs"].get("after_key") if after_key: - yield from _elastic_aggregate(index, sources, queries, filters, aggregations, - runtime_mappings=runtime_mappings, after_key=after_key) + yield from _elastic_aggregate( + index, sources, queries, filters, aggregations, runtime_mappings=runtime_mappings, after_key=after_key + ) -def _aggregate_results(index: Union[str, List[str]], axes: List[Axis], queries: Mapping[str, str], - filters: Optional[Mapping[str, Mapping]], aggregations: List[Aggregation]) -> Iterable[tuple]: +def _aggregate_results( + index: Union[str, List[str]], + axes: List[Axis], + queries: Mapping[str, str], + filters: Optional[Mapping[str, Mapping]], + aggregations: List[Aggregation], +) -> Iterable[tuple]: if not axes: # No axes, so return aggregations (or total count) only if aggregations: count, results = _bare_aggregate(index, queries, filters, aggregations) yield (count,) + tuple(a.get_value(results) for a in aggregations) else: - result = es().count(index=index if isinstance(index, str) else ",".join(index), - **build_body(queries=queries, filters=filters)) - yield result['count'], + result = es().count( + index=index if isinstance(index, str) else ",".join(index), **build_body(queries=queries, filters=filters) + ) + yield result["count"], elif any(ax.field == "_query" for ax in axes): # Strip off _query axis and run separate aggregation for each query i = [ax.field for ax in axes].index("_query") - _axes = axes[:i] + axes[(i+1):] + _axes = axes[:i] + axes[(i + 1) :] for label, query in queries.items(): for result_tuple in _aggregate_results(index, _axes, {label: query}, filters, aggregations): # insert label into the right position on the result tuple @@ -174,16 +196,21 @@ def _aggregate_results(index: Union[str, List[str]], axes: List[Axis], queries: sources = [axis.query() for axis in axes] runtime_mappings = _combine_mappings(axis.runtime_mappings() for axis in axes) for bucket in _elastic_aggregate(index, sources, queries, filters, aggregations, runtime_mappings): - row = tuple(axis.get_value(bucket['key']) for axis in axes) - row += (bucket['doc_count'], ) + row = tuple(axis.get_value(bucket["key"]) for axis in axes) + row += (bucket["doc_count"],) if aggregations: row += 
tuple(a.get_value(bucket) for a in aggregations)
             yield row


-def query_aggregate(index: Union[str, List[str]], axes: Sequence[Axis] = None, aggregations: Sequence[Aggregation] = None, *,
-                    queries: Union[Mapping[str, str], Sequence[str]] = None,
-                    filters: Mapping[str, Mapping] = None) -> AggregateResult:
+def query_aggregate(
+    index: str | list[str],
+    axes: list[Axis] | None = None,
+    aggregations: list[Aggregation] | None = None,
+    *,
+    queries: Mapping[str, str] | Sequence[str] | None = None,
+    filters: Mapping[str, Mapping] | None = None,
+) -> AggregateResult:
     """
     Conduct an aggregate query.
     Note that interval queries also yield zero counts for intervening keys without value,
@@ -199,15 +226,32 @@ def query_aggregate(index: Union[str, List[str]], axes: Sequence[Axis] = None, a
     """
     if axes and len([x for x in axes if x.field == "_query"]) > 1:
         raise ValueError("Only one aggregation axis may be by query")
-    fields = get_fields(index)
+
+    all_fields: dict[str, Field] = dict()
+    indices = index if isinstance(index, list) else [index]
+    for ix in indices:
+        index_fields = get_fields(ix)
+        for field_name, field in index_fields.items():
+            if field_name not in all_fields:
+                all_fields[field_name] = field
+            else:
+                if field.type != all_fields[field_name].type:
+                    raise ValueError(f"Type of {field_name} is not the same in all indices")
+
     if not axes:
         axes = []
     for axis in axes:
-        axis.ftype = "_query" if axis.field == "_query" else fields[axis.field]['type']
+        axis.ftype = "_query" if axis.field == "_query" else all_fields[axis.field].type
     if not aggregations:
         aggregations = []
     for aggregation in aggregations:
-        aggregation.ftype = fields[aggregation.field]['type']
+        aggregation.ftype = all_fields[aggregation.field].type
     queries = _normalize_queries(queries)
     data = list(_aggregate_results(index, axes, queries, filters, aggregations))
-    return AggregateResult(axes, aggregations, data, count_column="n", )
+    return AggregateResult(
+        axes,
+        aggregations,
+        data,
+        count_column="n",
+    )
diff --git a/amcat4/api/index.py b/amcat4/api/index.py
index b1c22fc..e15e7fe 100644
--- a/amcat4/api/index.py
+++ b/amcat4/api/index.py
@@ -8,44 +8,34 @@

 from fastapi.params import Body, Depends
 from pydantic import BaseModel, ConfigDict

-from amcat4 import elastic, index
+from amcat4 import index
 from amcat4.api.auth import authenticated_user, authenticated_writer, check_role
 from amcat4.api.common import py2dict
-from amcat4.index import (
-    Index,
-    IndexDoesNotExist,
-    Role,
-    get_global_role,
-    get_index,
-    get_role,
-    list_known_indices,
-    list_users,
-)
+
 from amcat4.index import refresh_index as es_refresh_index
 from amcat4.index import refresh_system_index, remove_role, set_role
+from amcat4.models import Field

 app_index = APIRouter(prefix="/index", tags=["index"])

-RoleType = Literal[
-    "ADMIN", "WRITER", "READER", "METAREADER", "admin", "writer", "reader", "metareader"
-]
+RoleType = Literal["ADMIN", "WRITER", "READER", "METAREADER", "admin", "writer", "reader", "metareader"]


 @app_index.get("/")
-def index_list(current_user: str = Depends(authenticated_user)):
+def index_list(current_user=Depends(authenticated_user)):
     """
     List indices from this server.
Returns a list of dicts containing name, role, and guest attributes """ - def index_to_dict(ix: Index) -> dict: - ix = ix._asdict() - ix["guest_role"] = ix["guest_role"] and ix["guest_role"].name - del ix["roles"] - return ix + def index_to_dict(ix: index.Index) -> dict: + ix_dict = ix._asdict() + ix_dict["guest_role"] = ix_dict["guest_role"] and ix_dict["guest_role"].name + del ix_dict["roles"] + return ix_dict - return [index_to_dict(ix) for ix in list_known_indices(current_user)] + return [index_to_dict(ix) for ix in index.list_known_indices(current_user)] class NewIndex(BaseModel): @@ -58,15 +48,13 @@ class NewIndex(BaseModel): @app_index.post("/", status_code=status.HTTP_201_CREATED) -def create_index( - new_index: NewIndex, current_user: str = Depends(authenticated_writer) -): +def create_index(new_index: NewIndex, current_user=Depends(authenticated_writer)): """ Create a new index, setting the current user to admin (owner). POST data should be json containing name and optional guest_role """ - guest_role = new_index.guest_role and Role[new_index.guest_role.upper()] + guest_role = new_index.guest_role and index.Role[new_index.guest_role.upper()] try: index.create_index( new_index.id, @@ -78,9 +66,7 @@ def create_index( except ApiError as e: raise HTTPException( status_code=400, - detail=dict( - info=f"Error on creating index: {e}", message=e.message, body=e.body - ), + detail=dict(info=f"Error on creating index: {e}", message=e.message, body=e.body), ) @@ -88,20 +74,18 @@ def create_index( class ChangeIndex(BaseModel): """Form to update an existing index.""" - guest_role: Optional[ - Literal[ - "ADMIN", - "WRITER", - "READER", - "METAREADER", - "admin", - "writer", - "reader", - "metareader", - "NONE", - "none", - ] - ] = "None" + guest_role: Literal[ + "ADMIN", + "WRITER", + "READER", + "METAREADER", + "admin", + "writer", + "reader", + "metareader", + "NONE", + "none", + ] | None = "None" name: Optional[str] = None description: Optional[str] = None summary_field: Optional[str] = None @@ -116,14 +100,14 @@ def modify_index(ix: str, data: ChangeIndex, user: str = Depends(authenticated_u User needs admin rights on the index """ - check_role(user, Role.ADMIN, ix) + check_role(user, index.Role.ADMIN, ix) guest_role, remove_guest_role = None, False if data.guest_role: role = data.guest_role.upper() if role == "NONE": remove_guest_role = True else: - guest_role = Role[role] + guest_role = index.Role[role] index.modify_index( ix, name=data.name, @@ -141,23 +125,19 @@ def view_index(ix: str, user: str = Depends(authenticated_user)): View the index. 
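     An illustrative response shape (values are made up; the actual fields come
     from the stored index document):
         {"id": "my_index", "name": "My Index", "description": "...",
          "guest_role": "READER", "user_role": "ADMIN"}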
""" try: - role = check_role(user, Role.METAREADER, ix, required_global_role=Role.WRITER) - d = get_index(ix)._asdict() + role = check_role(user, index.Role.METAREADER, ix, required_global_role=index.Role.WRITER) + d = index.get_index(ix)._asdict() d["user_role"] = role and role.name d["guest_role"] = d["guest_role"].name if d.get("guest_role") else None return d - except IndexDoesNotExist: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist" - ) + except index.IndexDoesNotExist: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") -@app_index.delete( - "/{ix}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response -) +@app_index.delete("/{ix}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) def delete_index(ix: str, user: str = Depends(authenticated_user)): """Delete the index.""" - check_role(user, Role.ADMIN, ix) + check_role(user, index.Role.ADMIN, ix) index.delete_index(ix) @@ -175,9 +155,7 @@ class Document(BaseModel): def upload_documents( ix: str, documents: List[Document] = Body(None, description="The documents to upload"), - columns: Optional[Mapping[str, str]] = Body( - None, description="Optional Specification of field (column) types" - ), + columns: Optional[Mapping[str, str]] = Body(None, description="Optional Specification of field (column) types"), user: str = Depends(authenticated_user), ): """ @@ -190,9 +168,9 @@ def upload_documents( } Returns a list of ids for the uploaded documents """ - check_role(user, Role.WRITER, ix) + check_role(user, index.Role.WRITER, ix) documents = [py2dict(doc) for doc in documents] - return elastic.upload_documents(ix, documents, columns) + return index.upload_documents(ix, documents, columns) @app_index.get("/{ix}/documents/{docid}") @@ -208,12 +186,12 @@ def get_document( GET request parameters: fields - Comma separated list of fields to return (default: all fields) """ - check_role(user, Role.READER, ix) + check_role(user, index.Role.READER, ix) kargs = {} if fields: kargs["_source"] = fields try: - return elastic.get_document(ix, docid, **kargs) + return index.get_document(ix, docid, **kargs) except elasticsearch.exceptions.NotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -237,9 +215,9 @@ def update_document( PUT request body should be a json {field: value} mapping of fields to update """ - check_role(user, Role.WRITER, ix) + check_role(user, index.Role.WRITER, ix) try: - elastic.update_document(ix, docid, update) + index.update_document(ix, docid, update) except elasticsearch.exceptions.NotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -254,9 +232,9 @@ def update_document( ) def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user)): """Delete this document.""" - check_role(user, Role.WRITER, ix) + check_role(user, index.Role.WRITER, ix) try: - elastic.delete_document(ix, docid) + index.delete_document(ix, docid) except elasticsearch.exceptions.NotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -271,31 +249,29 @@ def get_fields(ix: str, user=Depends(authenticated_user)): Returns a json array of {name, type} objects """ - check_role(user, Role.METAREADER, ix) + check_role(user, index.Role.METAREADER, ix) if "," in ix: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"/[index]/fields does not support multiple indices", ) - return elastic.get_fields(ix) + return index.get_fields(ix) 
@app_index.post("/{ix}/fields")
-def set_fields(
-    ix: str, body: dict = Body(...), user: str = Depends(authenticated_user)
-):
+def set_fields(ix: str, fields: dict[str, Field] = Body(...), user=Depends(authenticated_user)):
     """
     Set the field types used in this index.

     POST body should be a dict of {field: {type: ..., metareader_access: ...}}
     """
-    check_role(user, Role.WRITER, ix)
-    elastic.set_fields(ix, body)
+    check_role(user, index.Role.WRITER, ix)
+    index.set_fields(ix, fields)
     return "", HTTPStatus.NO_CONTENT


 @app_index.get("/{ix}/fields/{field}/values")
-def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user)):
+def get_field_values(ix: str, field: str, user=Depends(authenticated_user)):
     """
     Get unique values for a specific field. Should mainly/only be used for tag fields.
     Main purpose is to provide a list of values for a dropdown menu.
@@ -305,8 +281,8 @@ def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user
     there should be a limit. Querying could be an option, but not sure if that is
     efficient, since elastic has to aggregate all values first.
     """
-    check_role(user, Role.READER, ix)
-    values = elastic.get_field_values(ix, field, size=2001)
+    check_role(user, index.Role.READER, ix)
+    values = index.get_field_values(ix, field, size=2001)
     if len(values) > 2000:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -316,30 +292,30 @@ def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user


 @app_index.get("/{ix}/fields/{field}/stats")
-def get_field_stats(ix: str, field: str, user: str = Depends(authenticated_user)):
+def get_field_stats(ix: str, field: str, user=Depends(authenticated_user)):
     """Get statistics for a specific field. Only works for numeric (incl date) fields."""
-    check_role(user, Role.READER, ix)
-    return elastic.get_field_stats(ix, field)
+    check_role(user, index.Role.READER, ix)
+    return index.get_field_stats(ix, field)


 @app_index.get("/{ix}/users")
-def list_index_users(ix: str, user: str = Depends(authenticated_user)):
+def list_index_users(ix: str, user=Depends(authenticated_user)):
     """
     List the users in this index.

     Allowed for global admin and local readers
     """
-    if get_global_role(user) != Role.ADMIN:
-        check_role(user, Role.READER, ix)
-    return [{"email": u, "role": r.name} for (u, r) in list_users(ix).items()]
+    if index.get_global_role(user) != index.Role.ADMIN:
+        check_role(user, index.Role.READER, ix)
+    return [{"email": u, "role": r.name} for (u, r) in index.list_users(ix).items()]


 def _check_can_modify_user(ix, user, target_user, target_role):
-    if get_global_role(user) != Role.ADMIN:
+    if index.get_global_role(user) != index.Role.ADMIN:
         required_role = (
-            Role.ADMIN
-            if (target_role == Role.ADMIN or get_role(ix, target_user) == Role.ADMIN)
-            else Role.WRITER
+            index.Role.ADMIN
+            if (target_role == index.Role.ADMIN or index.get_role(ix, target_user) == index.Role.ADMIN)
+            else index.Role.WRITER
         )
         check_role(user, required_role, ix)

@@ -357,7 +333,7 @@ def add_index_users(
     To create regular users you need WRITER permission. To create ADMIN users,
     you need ADMIN permission. Global ADMINs can always add users.
     """
-    r = Role[role]
+    r = index.Role[role]
     _check_can_modify_user(ix, user, email, r)
     set_role(ix, email, r)
     return {"user": email, "index": ix, "role": r.name}
@@ -376,7 +352,7 @@ def modify_index_user(
     This requires WRITER rights on the index or global ADMIN rights.
If changing a user from or to ADMIN, it requires (local or global) ADMIN rights """ - r = Role[role] + r = index.Role[role] _check_can_modify_user(ix, user, email, r) set_role(ix, email, r) return {"user": email, "index": ix, "role": r.name} @@ -395,8 +371,6 @@ def remove_index_user(ix: str, email: str, user: str = Depends(authenticated_use return {"user": email, "index": ix, "role": None} -@app_index.get( - "/{ix}/refresh", status_code=status.HTTP_204_NO_CONTENT, response_class=Response -) +@app_index.get("/{ix}/refresh", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) def refresh_index(ix: str): es_refresh_index(ix) diff --git a/amcat4/elastic.py b/amcat4/elastic.py index f6e7d9b..8ee6908 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -7,58 +7,18 @@ - The system index contains a 'document' for each used index containing: {auth: [{email: role}], guest_role: role} - We define the mappings (field types) based on existing elasticsearch mappings, - but use field metadata to define specific fields, see ES_MAPPINGS below. + but use field metadata to define specific fields. """ import functools -import hashlib -import json + import logging -import re -from typing import Mapping, List, Iterable, Optional, Tuple, Union, Sequence, Literal +from typing import Optional from elasticsearch import Elasticsearch, NotFoundError -from elasticsearch.helpers import bulk -from amcat4.util import parse_field - from amcat4.config import get_settings SYSTEM_INDEX_VERSION = 1 -ES_MAPPINGS = { - "long": {"type": "long", "meta": {"metareader_access": "read"}}, - "date": { - "type": "date", - "format": "strict_date_optional_time", - "meta": {"metareader_access": "read"}, - }, - "double": {"type": "double", "meta": {"metareader_access": "read"}}, - "keyword": {"type": "keyword", "meta": {"metareader_access": "read"}}, - "url": { - "type": "keyword", - "meta": {"amcat4_type": "url", "metareader_access": "read"}, - }, - "tag": { - "type": "keyword", - "meta": {"amcat4_type": "tag", "metareader_access": "read"}, - }, - "id": { - "type": "keyword", - "meta": {"amcat4_type": "id", "metareader_access": "read"}, - }, - "text": {"type": "text", "meta": {"metareader_access": "none"}}, - "object": {"type": "object", "meta": {"metareader_access": "none"}}, - "geo_point": {"type": "geo_point", "metareader_access": "read"}, - "dense_vector": {"type": "dense_vector", "metareader_access": "none"}, -} - -DEFAULT_MAPPING = { - "text": ES_MAPPINGS["text"], - "title": ES_MAPPINGS["text"], - "date": ES_MAPPINGS["date"], - "url": ES_MAPPINGS["url"], -} - - SYSTEM_MAPPING = { "name": {"type": "text"}, "description": {"type": "text"}, @@ -77,9 +37,7 @@ def es() -> Elasticsearch: try: return _setup_elastic() except ValueError as e: - raise ValueError( - f"Cannot connect to elastic {get_settings().elastic_host!r}: {e}" - ) + raise ValueError(f"Cannot connect to elastic {get_settings().elastic_host!r}: {e}") def connect_elastic() -> Elasticsearch: @@ -108,9 +66,7 @@ def get_system_version(elastic=None) -> Optional[int]: if elastic is None: elastic = es() try: - r = elastic.get( - index=settings.system_index, id=GLOBAL_ROLES, source_includes="version" - ) + r = elastic.get(index=settings.system_index, id=GLOBAL_ROLES, source_includes="version") except NotFoundError: return None return r["_source"]["version"] @@ -130,9 +86,7 @@ def _setup_elastic(): ) elastic = connect_elastic() if not elastic.ping(): - raise CannotConnectElastic( - f"Cannot connect to elasticsearch server {settings.elastic_host}" - ) + raise 
CannotConnectElastic(f"Cannot connect to elasticsearch server {settings.elastic_host}") if elastic.indices.exists(index=settings.system_index): # Check index format version if get_system_version(elastic) is None: @@ -143,9 +97,7 @@ def _setup_elastic(): else: logging.info(f"Creating amcat4 system index: {settings.system_index}") - elastic.indices.create( - index=settings.system_index, mappings={"properties": SYSTEM_MAPPING} - ) + elastic.indices.create(index=settings.system_index, mappings={"properties": SYSTEM_MAPPING}) elastic.index( index=settings.system_index, id=GLOBAL_ROLES, @@ -154,277 +106,6 @@ def _setup_elastic(): return elastic -def coerce_type_to_elastic(value, ftype): - """ - Coerces values into the respective type in elastic - based on ES_MAPPINGS and elastic field types - """ - if ftype in ["keyword", "constant_keyword", "wildcard", "url", "tag", "text"]: - value = str(value) - elif ftype in [ - "long", - "short", - "byte", - "double", - "float", - "half_float", - "half_float", - "unsigned_long", - ]: - value = float(value) - elif ftype in ["integer"]: - value = int(value) - elif ftype == "boolean": - value = bool(value) - return value - - -def _get_hash(document: dict) -> bytes: - """ - Get the hash for a document - """ - hash_str = json.dumps( - document, sort_keys=True, ensure_ascii=True, default=str - ).encode("ascii") - m = hashlib.sha224() - m.update(hash_str) - return m.hexdigest() - - -def upload_documents(index: str, documents, fields: Mapping[str, str] = None) -> None: - """ - Upload documents to this index - - :param index: The name of the index (without prefix) - :param documents: A sequence of article dictionaries - :param fields: A mapping of field:type for field types - """ - - def es_actions(index, documents): - field_types = get_index_fields(index) - for document in documents: - for key in document.keys(): - if key in field_types: - document[key] = coerce_type_to_elastic( - document[key], field_types[key].get("type") - ) - if "_id" not in document: - document["_id"] = _get_hash(document) - yield {"_index": index, **document} - - if fields: - set_fields(index, fields) - - actions = list(es_actions(index, documents)) - bulk(es(), actions) - - -def get_field_mapping(current: dict, update: Union[str, dict]): - if isinstance(update, str): - type = update - newmeta = None - else: - if "type" not in update: - type = current.get("type") - if type is None: - raise ValueError("Field type is not specified") - type = update.get("type") - if type not in ES_MAPPINGS: - raise ValueError(f"Invalid field type: {type}") - newmeta = update.get("meta", None) - - if "meta" in current: - meta = current["meta"] - else: - meta = ES_MAPPINGS[type].get("meta", {}) - if newmeta: - meta.update(newmeta) - - return dict(type=type, meta=meta) - - -def validate_field_meta(meta: dict): - """ - Elastic has limited available field meta. 
Here we validate the allowed keys (and values) - """ - valid_fields = ["amcat4_type", "metareader_access", "client_display"] - - for meta_field in meta.keys(): - # Validate keys - if meta_field not in valid_fields: - raise ValueError(f"Invalid meta field: {meta_field}") - - # Validate values - if not isinstance(meta[meta_field], str): - raise ValueError("Meta field value has to be a string") - - if meta_field == "amcat4_type": - if meta[meta_field] not in ES_MAPPINGS.keys(): - raise ValueError(f"Invalid amcat4_type value: {meta[meta_field]}") - - if meta_field == "client_display": - # client_display only concerns the client - continue - - if meta_field == "metareader_access": - # metareader_access can be "none", "read", or "snippet" - # if snippet, can also include the maximum snippet parameters (nomatch_chars, max_matches, match_chars) - # in the format: snippet[nomatch_chars;max_matches;match_chars] - reg = r"^(read|none|snippet\[\d+;\d+;\d+\])$" - if not re.match(reg, meta[meta_field]): - raise ValueError(f"Invalid metareader_access value: {meta[meta_field]}") - - return meta - - -def set_fields(index: str, fields: Mapping[str, Union[str, dict]]): - """ - Update the column types for this index - - :param index: The name of the index (without prefix) - :param fields: A mapping of field:type for column types - """ - index_fields = get_index_fields(index) - properties = { - field: get_field_mapping(index_fields.get(field, {}), update) - for (field, update) in fields.items() - } - es().indices.put_mapping(index=index, properties=properties) - - -def get_document(index: str, doc_id: str, **kargs) -> dict: - """ - Get a single document from this index. - - :param index: The name of the index - :param doc_id: The document id (hash) - :return: the source dict of the document - """ - return es().get(index=index, id=doc_id, **kargs)["_source"] - - -def update_document(index: str, doc_id: str, fields: dict): - """ - Update a single document. - - :param index: The name of the index - :param doc_id: The document id (hash) - :param fields: a {field: value} mapping of fields to update - """ - # Mypy doesn't understand that body= has been deprecated already... 
- es().update(index=index, id=doc_id, doc=fields) # type: ignore - - -def delete_document(index: str, doc_id: str): - """ - Delete a single document - - :param index: The name of the index - :param doc_id: The document id (hash) - """ - es().delete(index=index, id=doc_id) - - -def _get_type_from_property(properties: dict) -> str: - """ - Convert an elastic 'property' into an amcat4 field type - """ - result = properties.get("meta", {}).get("amcat4_type") - properties["type"] = properties.get("type", "object") - if result: - return result - return properties["type"] - - -def _get_fields(index: str) -> Iterable[Tuple[str, dict]]: - r = es().indices.get_mapping(index=index) - - for k, v in r[index]["mappings"]["properties"].items(): - t = dict(name=k, type=_get_type_from_property(v)) - if meta := v.get("meta"): - t["meta"] = meta - yield k, t - - -def get_index_fields(index: str) -> Mapping[str, dict]: - """ - Get the field types in use in this index - :param index: - :return: a dict of fieldname: field objects {fieldname: {name, type, meta, ...}] - """ - return dict(_get_fields(index)) - - -def get_fields(index: str) -> Mapping[str, dict]: - """ - Get the field types in use in this index or indices - :param index: name(s) of index(es) to query - :return: a dict of fieldname: field objects {fieldname: {name, type, ...}] - """ - if not isinstance(index, str): - raise ValueError("get_fields only supports a single index") - - return get_index_fields(index) - - -def get_field_values(index: str, field: str, size: int) -> List[str]: - """ - Get the values for a given field (e.g. to populate list of filter values on keyword field) - Results are sorted descending by document frequency - see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html#search-aggregations-bucket-terms-aggregation-order - - :param index: The index - :param field: The field name - :return: A list of values - """ - aggs = {"unique_values": {"terms": {"field": field, "size": size}}} - r = es().search(index=index, size=0, aggs=aggs) - return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]] - - -def get_field_stats(index: str, field: str) -> List[str]: - """ - Get field statistics, such as min, max, avg, etc. 
- :param index: The index - :param field: The field name - :return: A list of values - """ - aggs = {"facets": {"stats": {"field": field}}} - r = es().search(index=index, size=0, aggs=aggs) - return r["aggregations"]["facets"] - - -def update_by_query(index: str, script: str, query: dict, params: dict = None): - script = dict(source=script, lang="painless", params=params or {}) - es().update_by_query(index=index, script=script, **query) - - -TAG_SCRIPTS = dict( - add=""" - if (ctx._source[params.field] == null) { - ctx._source[params.field] = [params.tag] - } else if (!ctx._source[params.field].contains(params.tag)) { - ctx._source[params.field].add(params.tag) - } - """, - remove=""" - if (ctx._source[params.field] != null && ctx._source[params.field].contains(params.tag)) { - ctx._source[params.field].removeAll([params.tag]); - if (ctx._source[params.field].size() == 0) { - ctx._source.remove(params.field); - } - }""", -) - - -def update_tag_by_query( - index: str, action: Literal["add", "remove"], query: dict, field: str, tag: str -): - script = TAG_SCRIPTS[action] - params = dict(field=field, tag=tag) - update_by_query(index, script, query, params) - - def ping(): """ Can we reach this elasticsearch server diff --git a/amcat4/index.py b/amcat4/index.py index cff55ce..45c06f3 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -31,14 +31,55 @@ """ import collections from enum import IntEnum -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Literal + +import hashlib +import json import elasticsearch.helpers from elasticsearch import NotFoundError from amcat4.config import get_settings -from amcat4.elastic import DEFAULT_MAPPING, es, get_fields -from amcat4.models import FieldSettings, updateFieldSettings +from amcat4.elastic import es +from amcat4.models import Field, UpdateField, updateField, FieldMetareaderAccess + + +# The Field model has a type field as we use it in amcat, but we need to +# convert this to an elastic type. 
This is the mapping +ES_MAPPINGS = { + "long": {"type": "long"}, + "date": {"type": "date", "format": "strict_date_optional_time"}, + "double": {"type": "double"}, + "keyword": {"type": "keyword"}, + "url": {"type": "keyword"}, + "tag": {"type": "keyword"}, + "id": {"type": "keyword"}, + "text": {"type": "text"}, + "object": {"type": "object"}, + "geo_point": {"type": "geo_point"}, + "dense_vector": {"type": "dense_vector"}, +} + +DEFAULT_METAREADER = { + "long": FieldMetareaderAccess(access="read"), + "date": FieldMetareaderAccess(access="read"), + "double": FieldMetareaderAccess(access="read"), + "keyword": FieldMetareaderAccess(access="read"), + "url": FieldMetareaderAccess(access="read"), + "tag": FieldMetareaderAccess(access="read"), + "id": FieldMetareaderAccess(access="read"), + "text": FieldMetareaderAccess(access="none"), + "object": FieldMetareaderAccess(access="none"), + "geo_point": FieldMetareaderAccess(access="none"), + "dense_vector": FieldMetareaderAccess(access="none"), +} + +DEFAULT_INDEX_FIELDS = { + "text": Field(type="text", metareader_access=DEFAULT_METAREADER["text"]), + "title": Field(type="text", metareader_access=DEFAULT_METAREADER["text"]), + "date": Field(type="date", metareader_access=DEFAULT_METAREADER["date"]), + "url": Field(type="url", metareader_access=DEFAULT_METAREADER["url"]), +} class Role(IntEnum): @@ -75,7 +116,7 @@ def refresh_system_index(): es().indices.refresh(index=get_settings().system_index) -def list_known_indices(email: str = None) -> Iterable[Index]: +def list_known_indices(email: str | None = None) -> Iterable[Index]: """ List all known indices, e.g. indices registered in this amcat4 instance :param email: if given, only list indices visible to this user @@ -110,10 +151,10 @@ def _index_from_elastic(index): def get_index(index: str) -> Index: try: - index = es().get(index=get_settings().system_index, id=index) + index_resp = es().get(index=get_settings().system_index, id=index) except NotFoundError: raise IndexDoesNotExist(index) - return _index_from_elastic(index) + return _index_from_elastic(index_resp) def create_index( @@ -126,7 +167,7 @@ def create_index( """ Create a new index in elasticsearch and register it with this AmCAT instance """ - es().indices.create(index=index, mappings={"properties": DEFAULT_MAPPING}) + set_fields(index, DEFAULT_INDEX_FIELDS) register_index( index, guest_role=guest_role, @@ -225,7 +266,7 @@ def set_role(index: str, email: str, role: Optional[Role]): ) -def set_global_role(email: str, role: Role): +def set_global_role(email: str, role: Role | None): """ Set the global role for this user """ @@ -239,35 +280,52 @@ def set_guest_role(index: str, guest_role: Optional[Role]): modify_index(index, guest_role=guest_role, remove_guest_role=(guest_role is None)) -def _fields_settings_to_elastic(fields_settings: Dict[str, FieldSettings]) -> List[Dict]: - return [{"field": field, "settings": settings} for field, settings in fields_settings.items()] +def _fields_to_elastic(fields: Dict[str, Field]) -> List[Dict]: + return [{"field": field, "settings": settings} for field, settings in fields.items()] -def _fields_settings_from_elastic( - fields_settings: List[Dict], -) -> Dict[str, FieldSettings]: - return {fs["field"]: fs["settings"] for fs in fields_settings} +def _fields_from_elastic( + fields: List[Dict], +) -> Dict[str, Field]: + return {fs["field"]: fs["settings"] for fs in fields} -def set_fields_settings(index: str, new_fields_settings: Dict[str, FieldSettings]): +def set_fields(index: str, new_fields: dict[str, 
Field] | dict[str, UpdateField]):
     """
-    Set the fields settings for this index
+    Set the fields settings for this index.
+
+    Note that if UpdateField also updates the field type, we need to update the mapping as well.
     """
     system_index = get_settings().system_index
     try:
-        d = es().get(index=system_index, id=index, source_includes="fields_settings")
+        d = es().get(index=system_index, id=index, source_includes="fields")
     except NotFoundError:
         raise ValueError(f"Index {index} is not registered")

-    fields_settings = _fields_settings_from_elastic(d["_source"].get("fields_settings", {}))
-    for field, new_settings in new_fields_settings.items():
-        current: FieldSettings = fields_settings.get(field, FieldSettings())
-        fields_settings[field] = updateFieldSettings(current, new_settings)
+    type_mappings = {}
+    fields = _fields_from_elastic(d["_source"].get("fields", {}))
+
+    for field, new_settings in new_fields.items():
+        if new_settings.type is not None:
+            # if new type is specified, we need to update the index mapping properties
+            type_mappings[field] = ES_MAPPINGS[new_settings.type]
+
+        current = fields.get(field)
+        if current is None:
+            # Create field
+            if new_settings.type is None:
+                raise ValueError(f"Field {field} does not yet exist, and to create a new field you need to specify a type")
+            fields[field] = Field(**new_settings.model_dump())
+        else:
+            # Update field
+            fields[field] = updateField(current, new_settings)
+
+    es().indices.put_mapping(index=index, properties=type_mappings)

     es().update(
         index=system_index,
         id=index,
-        doc=dict(roles=_fields_settings_to_elastic(fields_settings)),
+        doc=dict(fields=_fields_to_elastic(fields)),
     )
@@ -289,8 +347,8 @@ def modify_index(
             f = get_fields(index)
             if summary_field not in f:
                 raise ValueError(f"Summary field {summary_field} does not exist!")
-            if f[summary_field]["type"] not in ["date", "keyword", "tag"]:
-                raise ValueError(f"Summary field {summary_field} should be date, keyword or tag, not {f[summary_field]['type']}!")
+            if f[summary_field].type not in ["date", "keyword", "tag"]:
+                raise ValueError(f"Summary field {summary_field} should be date, keyword or tag, not {f[summary_field].type}!")
     doc = {x: v for (x, v) in doc.items() if v}
     if remove_guest_role:
         doc["guest_role"] = None
@@ -331,9 +389,10 @@ def get_role(index: str, email: str) -> Optional[Role]:
         return role
     if index == GLOBAL_ROLES:
         return None
-    role = doc["_source"].get("guest_role", None)
-    if role and role.lower() != "none":
-        return Role[role]
+
+    guest_role: str | None = doc["_source"].get("guest_role", None)
+    if guest_role and guest_role.lower() != "none":
+        return Role[guest_role]
     return None


@@ -369,7 +428,7 @@ def get_global_role(email: str, only_es: bool = False) -> Optional[Role]:
     return get_role(index=GLOBAL_ROLES, email=email)


-def get_fields_settings(index: str) -> Dict[str, FieldSettings]:
+def get_fields(index: str) -> Dict[str, Field]:
     """
     Retrieve the fields settings for this index
     """
@@ -377,11 +436,11 @@ def get_fields_settings(index: str) -> Dict[str, FieldSettings]:
         d = es().get(
             index=get_settings().system_index,
             id=index,
-            source_includes="fields_settings",
+            source_includes="fields",
         )
     except NotFoundError:
         raise IndexDoesNotExist(index)
-    return _fields_settings_from_elastic(d["_source"].get("fields_settings", {}))
+    return _fields_from_elastic(d["_source"].get("fields", {}))


 def list_users(index: str) -> Dict[str, Role]:
@@ -407,3 +466,154 @@ def delete_user(email: str) -> None:
     set_global_role(email, None)
     for ix in list_known_indices(email):
set_role(ix.id, email, None)
+
+
+def coerce_type_to_elastic(value, ftype):
+    """
+    Coerces values into the respective type in elastic
+    based on ES_MAPPINGS and elastic field types
+    """
+    if ftype in ["keyword", "constant_keyword", "wildcard", "url", "tag", "text"]:
+        value = str(value)
+    elif ftype in [
+        "long",
+        "short",
+        "byte",
+        "double",
+        "float",
+        "half_float",
+        "unsigned_long",
+    ]:
+        value = float(value)
+    elif ftype in ["integer"]:
+        value = int(value)
+    elif ftype == "boolean":
+        value = bool(value)
+    return value
+
+
+def _get_hash(document: dict) -> str:
+    """
+    Get the hash for a document
+    """
+    hash_str = json.dumps(document, sort_keys=True, ensure_ascii=True, default=str).encode("ascii")
+    m = hashlib.sha224()
+    m.update(hash_str)
+    return m.hexdigest()
+
+
+def upload_documents(index: str, documents, fields: dict[str, Field] | None = None) -> None:
+    """
+    Upload documents to this index
+
+    :param index: The name of the index (without prefix)
+    :param documents: A sequence of article dictionaries
+    :param fields: An optional mapping of {field: Field} settings to set before uploading
+    """
+
+    def es_actions(index, documents):
+        field_types = get_fields(index)
+        for document in documents:
+            for key in document.keys():
+                if key == "_id":
+                    # _id is elastic's document id, not a mapped field
+                    continue
+                if key not in field_types:
+                    raise ValueError(f"The type for field {key} is not yet specified")
+                document[key] = coerce_type_to_elastic(document[key], field_types[key].type)
+            if "_id" not in document:
+                document["_id"] = _get_hash(document)
+            yield {"_index": index, **document}
+
+    if fields:
+        set_fields(index, fields)
+
+    actions = list(es_actions(index, documents))
+    elasticsearch.helpers.bulk(es(), actions)
+
+
+def get_document(index: str, doc_id: str, **kargs) -> dict:
+    """
+    Get a single document from this index.
+
+    :param index: The name of the index
+    :param doc_id: The document id (hash)
+    :return: the source dict of the document
+    """
+    return es().get(index=index, id=doc_id, **kargs)["_source"]
+
+
+def update_document(index: str, doc_id: str, fields: dict):
+    """
+    Update a single document.
+
+    :param index: The name of the index
+    :param doc_id: The document id (hash)
+    :param fields: a {field: value} mapping of fields to update
+    """
+    # Mypy doesn't understand that body= has been deprecated already...
+    es().update(index=index, id=doc_id, doc=fields)  # type: ignore
+
+
+def delete_document(index: str, doc_id: str):
+    """
+    Delete a single document
+
+    :param index: The name of the index
+    :param doc_id: The document id (hash)
+    """
+    es().delete(index=index, id=doc_id)
+
+
+def get_field_values(index: str, field: str, size: int) -> List[str]:
+    """
+    Get the values for a given field (e.g. to populate list of filter values on keyword field)
+    Results are sorted descending by document frequency
+    see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html#search-aggregations-bucket-terms-aggregation-order
+
+    :param index: The index
+    :param field: The field name
+    :return: A list of values
+    """
+    aggs = {"unique_values": {"terms": {"field": field, "size": size}}}
+    r = es().search(index=index, size=0, aggs=aggs)
+    return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]]
+
+
+def get_field_stats(index: str, field: str) -> dict:
+    """
+    Get field statistics, such as min, max, avg, etc.
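+    The result is elastic's 'stats' aggregation, shaped roughly like this
+    (illustrative values): {"count": 100, "min": 1.0, "max": 9.0, "avg": 5.0, "sum": 500.0}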
+    :param index: The index
+    :param field: The field name
+    :return: A list of values
+    """
+    aggs = {"facets": {"stats": {"field": field}}}
+    r = es().search(index=index, size=0, aggs=aggs)
+    return r["aggregations"]["facets"]
+
+
+def update_by_query(index: str, script: str, query: dict, params: dict = None):
+    script = dict(source=script, lang="painless", params=params or {})
+    es().update_by_query(index=index, script=script, **query)
+
+
+TAG_SCRIPTS = dict(
+    add="""
+    if (ctx._source[params.field] == null) {
+      ctx._source[params.field] = [params.tag]
+    } else if (!ctx._source[params.field].contains(params.tag)) {
+      ctx._source[params.field].add(params.tag)
+    }
+    """,
+    remove="""
+    if (ctx._source[params.field] != null && ctx._source[params.field].contains(params.tag)) {
+      ctx._source[params.field].removeAll([params.tag]);
+      if (ctx._source[params.field].size() == 0) {
+        ctx._source.remove(params.field);
+      }
+    }""",
+)
+
+
+def update_tag_by_query(index: str, action: Literal["add", "remove"], query: dict, field: str, tag: str):
+    script = TAG_SCRIPTS[action]
+    params = dict(field=field, tag=tag)
+    update_by_query(index, script, query, params)
diff --git a/amcat4/models.py b/amcat4/models.py
index aed6c68..2213c28 100644
--- a/amcat4/models.py
+++ b/amcat4/models.py
@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import Optional
+from typing import Literal


 class SnippetParams(BaseModel):
@@ -17,17 +17,25 @@ class SnippetParams(BaseModel):
 class FieldMetareaderAccess(BaseModel):
     """Metareader access for a specific field."""

-    access: bool
-    max_snippet: Optional[SnippetParams]
+    access: Literal["none", "read", "snippet"]
+    max_snippet: SnippetParams | None = None


-class FieldSettings(BaseModel):
+class Field(BaseModel):
     """Settings for a field."""

-    metareader_access: Optional[FieldMetareaderAccess] = None
+    type: str
+    metareader_access: FieldMetareaderAccess


+class UpdateField(BaseModel):
+    """Model for updating a field"""
+
+    type: str | None = None
+    metareader_access: FieldMetareaderAccess | None = None
+
+
-def updateFieldSettings(field: FieldSettings, update: FieldSettings):
+def updateField(field: Field, update: UpdateField | Field):
     for key in update.model_fields_set:
         setattr(field, key, getattr(update, key))
     return field
diff --git a/tests/test_elastic.py b/tests/test_elastic.py
index d6d271e..e00fedc 100644
--- a/tests/test_elastic.py
+++ b/tests/test_elastic.py
@@ -1,35 +1,45 @@
 from datetime import datetime

-from amcat4 import elastic
-from amcat4.elastic import get_fields
-from amcat4.index import refresh_index
+from amcat4.index import (
+    refresh_index,
+    upload_documents,
+    get_document,
+    set_fields,
+    get_fields,
+    update_document,
+    update_tag_by_query,
+    get_field_values,
+)
 from amcat4.query import query_documents
 from tests.conftest import upload


 def test_upload_retrieve_document(index):
     """Can we upload and retrieve documents"""
-    a = dict(text="text", title="title", date="2021-03-09", _id="test", term_tfidf=[
-        {"term": "test", "value": 0.2},
-        {"term": "value", "value": 0.3}
-    ])
-    elastic.upload_documents(index, [a])
-    d = elastic.get_document(index, "test")
-    assert d['title'] == a['title']
-    assert d['term_tfidf'] == a['term_tfidf']
+    a = dict(
+        text="text",
+        title="title",
+        date="2021-03-09",
+        _id="test",
+        term_tfidf=[{"term": "test", "value": 0.2}, {"term": "value", "value": 0.3}],
+    )
+    upload_documents(index, [a])
+    d = get_document(index, "test")
+    assert d["title"] == a["title"]
+    assert d["term_tfidf"] == a["term_tfidf"]

     # TODO: should
a['date'] be a datetime? def test_data_coerced(index): """Are field values coerced to the correct field type""" - elastic.set_fields(index, {"i": "long"}) + set_fields(index, {"i": "long"}) a = dict(_id="DoccyMcDocface", text="text", title="test-numeric", date="2022-12-13", i="1") - elastic.upload_documents(index, [a]) - d = elastic.get_document(index, "DoccyMcDocface") + upload_documents(index, [a]) + d = get_document(index, "DoccyMcDocface") assert isinstance(d["i"], float) a = dict(text="text", title=1, date="2022-12-13") - elastic.upload_documents(index, [a]) - d = elastic.get_document(index, "DoccyMcDocface") + upload_documents(index, [a]) + d = get_document(index, "DoccyMcDocface") assert isinstance(d["title"], str) @@ -37,20 +47,20 @@ def test_fields(index): """Can we get the fields from an index""" fields = get_fields(index) assert set(fields.keys()) == {"title", "date", "text", "url"} - assert fields['date']['type'] == "date" + assert fields["date"]["type"] == "date" def test_values(index): """Can we get values for a specific field""" upload(index, [dict(bla=x) for x in ["odd", "even", "even"] * 10], fields={"bla": "keyword"}) - assert set(elastic.get_field_values(index, "bla", 10)) == {"odd", "even"} + assert set(get_field_values(index, "bla", 10)) == {"odd", "even"} def test_update(index_docs): """Can we update a field on a document?""" - assert elastic.get_document(index_docs, '0', _source=['annotations']) == {} - elastic.update_document(index_docs, '0', {'annotations': {'x': 3}}) - assert elastic.get_document(index_docs, '0', _source=['annotations'])['annotations'] == {'x': 3} + assert get_document(index_docs, "0", _source=["annotations"]) == {} + update_document(index_docs, "0", {"annotations": {"x": 3}}) + assert get_document(index_docs, "0", _source=["annotations"])["annotations"] == {"x": 3} def test_add_tag(index_docs): @@ -58,30 +68,32 @@ def q(*ids): return dict(query=dict(ids={"values": ids})) def tags(): - return {doc['_id']: doc['tag'] - for doc in query_documents(index_docs, fields=["tag"]).data - if 'tag' in doc and doc['tag'] is not None} + return { + doc["_id"]: doc["tag"] + for doc in query_documents(index_docs, fields=["tag"]).data + if "tag" in doc and doc["tag"] is not None + } assert tags() == {} - elastic.update_tag_by_query(index_docs, "add", q('0', '1'), "tag", "x") + update_tag_by_query(index_docs, "add", q("0", "1"), "tag", "x") refresh_index(index_docs) - assert tags() == {'0': ['x'], '1': ['x']} - elastic.update_tag_by_query(index_docs, "add", q('1', '2'), "tag", "x") + assert tags() == {"0": ["x"], "1": ["x"]} + update_tag_by_query(index_docs, "add", q("1", "2"), "tag", "x") refresh_index(index_docs) - assert tags() == {'0': ['x'], '1': ['x'], '2': ['x']} - elastic.update_tag_by_query(index_docs, "add", q('2', '3'), "tag", "y") + assert tags() == {"0": ["x"], "1": ["x"], "2": ["x"]} + update_tag_by_query(index_docs, "add", q("2", "3"), "tag", "y") refresh_index(index_docs) - assert tags() == {'0': ['x'], '1': ['x'], '2': ['x', 'y'], '3': ['y']} - elastic.update_tag_by_query(index_docs, "remove", q('0', '2', '3'), "tag", "x") + assert tags() == {"0": ["x"], "1": ["x"], "2": ["x", "y"], "3": ["y"]} + update_tag_by_query(index_docs, "remove", q("0", "2", "3"), "tag", "x") refresh_index(index_docs) - assert tags() == {'1': ['x'], '2': ['y'], '3': ['y']} + assert tags() == {"1": ["x"], "2": ["y"], "3": ["y"]} def test_deduplication(index): doc = {"title": "titel", "text": "text", "date": datetime(2020, 1, 1)} - elastic.upload_documents(index, [doc]) + 
upload_documents(index, [doc]) refresh_index(index) assert query_documents(index).total_count == 1 - elastic.upload_documents(index, [doc]) + upload_documents(index, [doc]) refresh_index(index) assert query_documents(index).total_count == 1 From aba98d5425159ee5725dd95ca66edf254c394b7d Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 12 Jan 2024 17:42:39 +0100 Subject: [PATCH 09/80] typinggit add .! --- .vscode/settings.json | 4 +- amcat4/aggregate.py | 17 +- amcat4/api/auth.py | 87 ++++----- amcat4/api/index.py | 35 ++-- amcat4/api/query.py | 405 +++++++++++++++++----------------------- amcat4/api/users.py | 13 +- amcat4/date_mappings.py | 15 +- amcat4/elastic.py | 3 +- amcat4/index.py | 40 ++-- amcat4/models.py | 39 +++- amcat4/query.py | 154 +++++++-------- amcat4/util.py | 25 --- tests/conftest.py | 14 +- tests/test_aggregate.py | 101 ++++++---- tests/test_query.py | 21 +-- 15 files changed, 441 insertions(+), 532 deletions(-) delete mode 100644 amcat4/util.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 862f43b..2b4d815 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,5 +5,7 @@ }, "black-formatter.args": ["--line-length", "127"], "mypy.enabled": true, - "mypy.runUsingActiveInterpreter": true + "mypy.runUsingActiveInterpreter": true, + "python.analysis.typeCheckingMode": "basic", + "python.analysis.autoImportCompletions": true } diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index 1a81653..1ebbec0 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -2,13 +2,13 @@ Aggregate queries """ from datetime import datetime -from typing import Mapping, Iterable, Union, Tuple, Sequence, List, Dict, Optional +from typing import Mapping, Iterable, Union, Tuple, Sequence, List, Dict from amcat4.date_mappings import interval_mapping from amcat4.elastic import es from amcat4.index import get_fields -from amcat4.query import build_body, _normalize_queries -from amcat4.models import Field +from amcat4.query import build_body +from amcat4.models import Field, FilterSpec def _combine_mappings(mappings): @@ -169,8 +169,8 @@ def _elastic_aggregate( def _aggregate_results( index: Union[str, List[str]], axes: List[Axis], - queries: Mapping[str, str], - filters: Optional[Mapping[str, Mapping]], + queries: dict[str, str] | None, + filters: dict[str, FilterSpec] | None, aggregations: List[Aggregation], ) -> Iterable[tuple]: if not axes: @@ -184,6 +184,8 @@ def _aggregate_results( ) yield result["count"], elif any(ax.field == "_query" for ax in axes): + if queries is None: + raise ValueError("Queries must be specified when aggregating by query") # Strip off _query axis and run separate aggregation for each query i = [ax.field for ax in axes].index("_query") _axes = axes[:i] + axes[(i + 1) :] @@ -208,8 +210,8 @@ def query_aggregate( axes: list[Axis] | None = None, aggregations: list[Aggregation] | None = None, *, - queries: Mapping[str, str] | Sequence[str] | None = None, - filters: Mapping[str, Mapping] | None = None, + queries: dict[str, str] | None = None, + filters: dict[str, FilterSpec] | None = None, ) -> AggregateResult: """ Conduct an aggregate query. 
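     A minimal usage sketch (illustrative; assumes an index with a 'date' field
     and a numeric 'i' field):
         query_aggregate("my_index", axes=[Axis("date", interval="year")],
                         aggregations=[Aggregation("i", "avg")])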
@@ -247,7 +249,6 @@ def query_aggregate( aggregations = [] for aggregation in aggregations: aggregation.ftype = all_fields[aggregation.field].type - queries = _normalize_queries(queries) data = list(_aggregate_results(index, axes, queries, filters, aggregations)) return AggregateResult( axes, diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py index f2e542a..095d444 100644 --- a/amcat4/api/auth.py +++ b/amcat4/api/auth.py @@ -2,20 +2,17 @@ import functools import logging from datetime import datetime -from typing import Iterable import requests from authlib.common.errors import AuthlibBaseError from authlib.jose import jwt -from fastapi import HTTPException -from fastapi.params import Depends +from fastapi import HTTPException, Depends from fastapi.security import OAuth2PasswordBearer from starlette.status import HTTP_401_UNAUTHORIZED -from amcat4 import elastic -from amcat4.util import parse_field +from amcat4.models import FieldSpec from amcat4.config import get_settings, AuthOptions -from amcat4.index import Role, get_role, get_global_role +from amcat4.index import Role, get_role, get_global_role, get_fields oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) @@ -44,9 +41,7 @@ def verify_token(token: str) -> dict: if payload["exp"] < now: raise InvalidToken("Token expired") if payload["resource"] != get_settings().host: - raise InvalidToken( - f"Wrong host! {payload['resource']} != {get_settings().host}" - ) + raise InvalidToken(f"Wrong host! {payload['resource']} != {get_settings().host}") return payload @@ -77,24 +72,19 @@ def check_global_role(user: str, required_role: Role, raise_error=True): try: global_role = get_global_role(user) except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error on retrieving user {user}: {e}" - ) + raise HTTPException(status_code=500, detail=f"Error on retrieving user {user}: {e}") if global_role and global_role >= required_role: return global_role if raise_error: raise HTTPException( status_code=401, - detail=f"User {user} does not have global " - f"{required_role.name.title()} permissions on this instance", + detail=f"User {user} does not have global " f"{required_role.name.title()} permissions on this instance", ) else: return False -def check_role( - user: str, required_role: Role, index: str, required_global_role: Role = Role.ADMIN -): +def check_role(user: str, required_role: Role, index: str, required_global_role: Role = Role.ADMIN): """Check if the given user have at least the given role (in the index, if given), raise Exception otherwise. :param user: The email address of the authenticated user @@ -115,12 +105,11 @@ def check_role( else: raise HTTPException( status_code=401, - detail=f"User {user} does not have " - f"{required_role.name.title()} permissions on index {index}", + detail=f"User {user} does not have " f"{required_role.name.title()} permissions on index {index}", ) -def check_fields_access(index: str, user: str, fields: Iterable[str]) -> None: +def check_fields_access(index: str, user: str, fields: list[FieldSpec]) -> None: """Check if the given user is allowed to query the given fields and snippets on the given index. 
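     For example (an illustrative call), a metareader asking for a snippet rather than
     the full field content would pass
     fields=[FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=150, max_matches=3, match_chars=50))].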
:param index: The index to check the role on
@@ -142,47 +131,45 @@ def check_fields_access(index: str, user: str, fields: Iterable[str]) -> None:
         return None
 
     # after this, we know the user is a metareader, so we need to check metareader_access
-    index_fields = elastic.get_fields(index)
+    index_fields = get_fields(index)
     for field in fields:
-        fieldname, nomatch_chars, max_matches, match_chars = parse_field(field)
-        if fieldname not in index_fields.keys():
+        if field.name not in index_fields:
+            # might be better to raise an error here, but since we want to support querying multiple
+            # indices at once, this allows the user to query fields that do not exist on all indices
             continue
-        meta = index_fields[fieldname].get("meta", {})
-        metareader_access = meta.get("metareader_access", None)
-        if not metareader_access or metareader_access == "none":
-            raise HTTPException(
-                status_code=401,
-                detail=f"METAREADER cannot read {field} on index {index}",
-            )
-        if metareader_access == "read":
+        metareader = index_fields[field.name].metareader
+
+        if metareader.access == "read":
             continue
-        if metareader_access.startswith("snippet"):
-            (
-                _,
-                allowed_nomatch_chars,
-                allowed_max_matches,
-                allowed_match_chars,
-            ) = parse_field(metareader_access)
-            max_params = f"{fieldname}[{allowed_nomatch_chars};{allowed_max_matches};{allowed_match_chars}]"
-
-            if nomatch_chars is None:
+        elif metareader.access == "snippet" and metareader.max_snippet is not None:
+            max_params_msg = (
+                "Can only read snippet with max parameters:"
+                f"\n- nomatch_chars = {metareader.max_snippet.nomatch_chars}"
+                f"\n- max_matches = {metareader.max_snippet.max_matches}"
+                f"\n- match_chars = {metareader.max_snippet.match_chars}"
+            )
+
+            if field.snippet is None:
+                # if snippet is not specified, the whole field is requested
                 raise HTTPException(
-                    status_code=401,
-                    detail=f"METAREADER cannot read {field} on index {index}. "
-                    f"Can only read snippets with parameters: {max_params}",
+                    status_code=401, detail=f"METAREADER cannot read {field} on index {index}. {max_params_msg}"
                 )
-            valid_nomatch_chars = nomatch_chars <= allowed_nomatch_chars
-            valid_max_matches = max_matches <= allowed_max_matches
-            valid_match_chars = match_chars <= allowed_match_chars
-
+            valid_nomatch_chars = field.snippet.nomatch_chars <= metareader.max_snippet.nomatch_chars
+            valid_max_matches = field.snippet.max_matches <= metareader.max_snippet.max_matches
+            valid_match_chars = field.snippet.match_chars <= metareader.max_snippet.match_chars
             valid = valid_nomatch_chars and valid_max_matches and valid_match_chars
             if not valid:
                 raise HTTPException(
                     status_code=401,
-                    detail=f"The requested snippet of {fieldname} on index {index} is too long. "
-                    f"max parameters are: {max_params}",
+                    detail=f"The requested snippet of {field.name} on index {index} is too long. 
{max_params_msg}", ) + else: + raise HTTPException( + status_code=401, + detail=f"METAREADER cannot read {field} on index {index}", + ) async def authenticated_user(token: str = Depends(oauth2_scheme)) -> str: diff --git a/amcat4/api/index.py b/amcat4/api/index.py index e15e7fe..4dc6e00 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -4,25 +4,23 @@ import elasticsearch from elastic_transport import ApiError -from fastapi import APIRouter, HTTPException, Response, status -from fastapi.params import Body, Depends +from fastapi import APIRouter, HTTPException, Response, status, Depends, Body from pydantic import BaseModel, ConfigDict from amcat4 import index from amcat4.api.auth import authenticated_user, authenticated_writer, check_role from amcat4.api.common import py2dict -from amcat4.index import refresh_index as es_refresh_index from amcat4.index import refresh_system_index, remove_role, set_role from amcat4.models import Field app_index = APIRouter(prefix="/index", tags=["index"]) -RoleType = Literal["ADMIN", "WRITER", "READER", "METAREADER", "admin", "writer", "reader", "metareader"] +RoleType = Literal["ADMIN", "WRITER", "READER", "METAREADER"] @app_index.get("/") -def index_list(current_user=Depends(authenticated_user)): +def index_list(current_user: str = Depends(authenticated_user)): """ List index from this server. @@ -48,7 +46,7 @@ class NewIndex(BaseModel): @app_index.post("/", status_code=status.HTTP_201_CREATED) -def create_index(new_index: NewIndex, current_user=Depends(authenticated_writer)): +def create_index(new_index: NewIndex, current_user: str = Depends(authenticated_writer)): """ Create a new index, setting the current user to admin (owner). @@ -74,18 +72,7 @@ def create_index(new_index: NewIndex, current_user=Depends(authenticated_writer) class ChangeIndex(BaseModel): """Form to update an existing index.""" - guest_role: Literal[ - "ADMIN", - "WRITER", - "READER", - "METAREADER", - "admin", - "writer", - "reader", - "metareader", - "NONE", - "none", - ] | None = "None" + guest_role: Literal["ADMIN", "WRITER", "READER", "METAREADER", "NONE"] | None = "NONE" name: Optional[str] = None description: Optional[str] = None summary_field: Optional[str] = None @@ -243,7 +230,7 @@ def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user) @app_index.get("/{ix}/fields") -def get_fields(ix: str, user=Depends(authenticated_user)): +def get_fields(ix: str, user: str = Depends(authenticated_user)): """ Get the fields (columns) used in this index. @@ -259,7 +246,7 @@ def get_fields(ix: str, user=Depends(authenticated_user)): @app_index.post("/{ix}/fields") -def set_fields(ix: str, fields=dict[str, Field], user=Depends(authenticated_user)): +def set_fields(ix: str, fields=dict[str, Field], user: str = Depends(authenticated_user)): """ Set the field types used in this index. @@ -271,7 +258,7 @@ def set_fields(ix: str, fields=dict[str, Field], user=Depends(authenticated_user @app_index.get("/{ix}/fields/{field}/values") -def get_field_values(ix: str, field: str, user=Depends(authenticated_user)): +def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user)): """ Get unique values for a specific field. Should mainly/only be used for tag fields. Main purpose is to provide a list of values for a dropdown menu. 
@@ -292,14 +279,14 @@ def get_field_values(ix: str, field: str, user=Depends(authenticated_user)): @app_index.get("/{ix}/fields/{field}/stats") -def get_field_stats(ix: str, field: str, user=Depends(authenticated_user)): +def get_field_stats(ix: str, field: str, user: str = Depends(authenticated_user)): """Get statistics for a specific value. Only works for numeric (incl date) fields.""" check_role(user, index.Role.READER, ix) return index.get_field_stats(ix, field) @app_index.get("/{ix}/users") -def list_index_users(ix: str, user=Depends(authenticated_user)): +def list_index_users(ix: str, user: str = Depends(authenticated_user)): """ List the users in this index. @@ -373,4 +360,4 @@ def remove_index_user(ix: str, email: str, user: str = Depends(authenticated_use @app_index.get("/{ix}/refresh", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) def refresh_index(ix: str): - es_refresh_index(ix) + index.refresh_index(ix) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index fba7a9d..b946072 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -1,17 +1,16 @@ """API Endpoints for querying.""" -from typing import Dict, List, Optional, Any, Union, Iterable, Tuple, Literal +from typing import Annotated, Dict, List, Optional, Any, Union, Iterable, Literal -from fastapi import APIRouter, HTTPException, status, Request, Query, Depends, Response -from fastapi.params import Body +from fastapi import APIRouter, HTTPException, status, Depends, Response, Body from pydantic.main import BaseModel -from amcat4 import elastic, query, aggregate +from amcat4 import query, aggregate from amcat4.aggregate import Axis, Aggregation from amcat4.api.auth import authenticated_user, check_fields_access -from amcat4.index import Role, get_role +from amcat4.index import Role, get_role, get_fields +from amcat4.models import FieldSpec, FilterSpec, FilterValue, SortSpec from amcat4.query import update_tag_query -from amcat4.util import parse_field app_query = APIRouter(prefix="/index", tags=["query"]) @@ -34,8 +33,8 @@ class QueryResult(BaseModel): def get_or_validate_allowed_fields( - user: str, indices: Iterable[str], fields: Iterable[str] = None -): + user: str, indices: Iterable[str], fields: list[FieldSpec] | None = None +) -> list[FieldSpec]: """ For any endpoint that returns field values, make sure the user only gets fields that they are allowed to see. If fields is None, return all allowed fields. If fields is not None, @@ -56,242 +55,193 @@ def get_or_validate_allowed_fields( # this is error prone and complex, so best to just disallow it. 
Also, requesting all fields
        # for multiple indices is probably not something we should support anyway
            raise ValueError("Fields should be specified if multiple indices are given")
-        index_fields = elastic.get_fields(indices[0])
+        index_fields = get_fields(indices[0])
         role = get_role(indices[0], user)
-        all_fields = []
+        allowed_fields = []
         for field in index_fields.keys():
             if role >= Role.READER:
-                all_fields.append(field)
+                allowed_fields.append(FieldSpec(name=field))
             elif role == Role.METAREADER:
-                field_meta: dict = index_fields[field].get("meta", {})
-                metareader_access = field_meta.get("metareader_access", None)
-                if metareader_access == "read":
-                    all_fields.append(field)
-                elif "snippet" in metareader_access:
-                    _, nomatch_chars, max_matches, match_chars = parse_field(
-                        metareader_access
-                    )
-                    all_fields.append(
-                        f"{field}[{nomatch_chars};{max_matches};{match_chars}]"
-                    )
+                metareader = index_fields[field].metareader
+                if metareader.access == "read":
+                    allowed_fields.append(FieldSpec(name=field))
+                elif metareader.access == "snippet":
+                    allowed_fields.append(FieldSpec(name=field, snippet=metareader.max_snippet))
             else:
                 raise HTTPException(
                     status_code=401,
                     detail=f"User {user} does not have a role on index {indices[0]}",
                 )
+        return allowed_fields
     else:
+        fieldspecs = [field if isinstance(field, FieldSpec) else FieldSpec(name=field) for field in fields]
         for index in indices:
-            check_fields_access(index, user, fields)
-        return fields
-
-
-@app_query.get("/{index}/documents", response_model=QueryResult)
-def get_documents(
-    index: str,
-    request: Request,
-    q: List[str] = Query(
-        None,
-        description="Elastic query string. "
-        "Argument may be repeated for multiple queries (treated as OR)",
-    ),
-    sort: str = Query(
-        None,
-        description="Comma separated list of fields to sort on",
-        examples="id,date:desc",
-        pattern=r"\w+(:desc)?(,\w+(:desc)?)*",
-    ),
-    fields: str = Query(
-        None,
-        description="Comma separated list of fields to return. "
-        "You can also request a snippet of a field by appending the suffix [nomatch_chars;max_matches;match_chars]. "
-        "'matches' here refers to words from text queries. If there is no query, the snippet is the first [nomatch_chars] characters. "
-        "If there is a query, snippets are returned for up to [max_matches] matches, with each match having [match_chars] "
-        "characters. If there are multiple matches, they are concatenated with ' ... '.",
-        pattern=r"\w+(,\w+)*",
-    ),
-    highlight: bool = Query(False, description="If true, highlight fields"),
-    per_page: int = Query(None, description="Number of results per page"),
-    page: int = Query(None, description="Page to fetch"),
-    scroll: str = Query(
-        None,
-        description="Create a new scroll_id to download all results in subsequent calls",
-        examples="3m",
-    ),
-    scroll_id: str = Query(None, description="Get the next batch from this scroll id"),
-    user: str = Depends(authenticated_user),
-):
-    """
-    List (possibly filtered) documents in this index. 
- - Any additional GET parameters are interpreted as filters, and can be - field=value for a term query, or field__xxx=value for a range query, with xxx in gte, gt, lte, lt - Note that dates can use relative queries, see elasticsearch 'date math' - In case of conflict between field names and (other) arguments, you may prepend a field name with __ - If your field names contain __, it might be better to use POST queries - - Returns a JSON object {data: [...], meta: {total_count, per_page, page_count, page|scroll_id}} - """ - indices = index.split(",") - fields = fields and fields.split(",") - fields = get_or_validate_allowed_fields(user, indices, fields) - for index in indices: - check_fields_access(index, user, fields) - - args = {} - sort = sort and [ - {x.replace(":desc", ""): "desc"} if x.endswith(":desc") else x - for x in sort.split(",") - ] - known_args = ["page", "per_page", "scroll", "scroll_id", "highlight"] - for name in known_args: - val = locals()[name] - if val: - args[name] = int(val) if name in ["page", "per_page"] else val - filters: Dict[str, Dict] = {} - for f, v in request.query_params.items(): - if f not in known_args + ["fields", "sort", "q"]: - if f.startswith("__"): - f = f[2:] - if "__" in f: # range query - (field, operator) = f.split("__") - if field not in filters: - filters[field] = {} - filters[field][operator] = v - else: # value query - if f not in filters: - filters[f] = {"values": []} - filters[f]["values"].append(v) - r = query.query_documents( - indices, fields=fields, queries=q, filters=filters, sort=sort, **args - ) - if r is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No results") - return r.as_dict() - - -FilterValue = Union[str, int] - - -class FilterSpec(BaseModel): - """Form for filter specification.""" + check_fields_access(index, user, fieldspecs) + return fieldspecs - values: Optional[List[FilterValue]] = None - gt: Optional[FilterValue] = None - lt: Optional[FilterValue] = None - gte: Optional[FilterValue] = None - lte: Optional[FilterValue] = None - exists: Optional[bool] = None - -def _process_queries( - queries: Optional[Union[str, List[str], List[Dict[str, str]]]] = None -) -> Optional[dict]: +def _standardize_queries(queries: str | list[str] | dict[str, str] | None = None) -> dict[str, str] | None: """Convert query json to dict format: {label1:query1, label2: query2} uses indices if no labels given.""" + if queries: # to dict format: {label1:query1, label2: query2} uses indices if no labels given if isinstance(queries, str): - queries = [queries] - if isinstance(queries, list): - queries = {str(i): q for i, q in enumerate(queries)} - return queries + return {"1": queries} + elif isinstance(queries, list): + return {str(i): q for i, q in enumerate(queries)} + elif isinstance(queries, dict): + return queries + return None -def _process_filters( - filters: Optional[ - Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] - ] = None -) -> Iterable[Tuple[str, dict]]: +def _standardize_filters( + filters: dict[str, FilterValue | list[FilterValue] | FilterSpec] | None = None +) -> dict[str, FilterSpec] | None: """Convert filters to dict format: {field: {values: []}}.""" if not filters: - return + return None + + f = dict() for field, filter_ in filters.items(): if isinstance(filter_, str): - filter_ = [filter_] - if isinstance(filter_, list): - yield field, {"values": filter_} + f[field] = FilterSpec(values=[filter_]) + elif isinstance(filter_, list): + f[field] = FilterSpec(values=filter_) elif 
isinstance(filter_, FilterSpec):
-            yield field, {
-                k: v for (k, v) in filter_.model_dump().items() if v is not None
-            }
+            f[field] = filter_
         else:
             raise ValueError(f"Cannot parse filter: {filter_}")
+    return f
+
+
+def _standardize_fields(fields: list[str | FieldSpec] | None = None) -> list[FieldSpec] | None:
+    """Convert fields to list of FieldSpecs."""
+    if not fields:
+        return None
+
+    f = []
+    for field in fields:
+        if isinstance(field, str):
+            f.append(FieldSpec(name=field))
+        elif isinstance(field, FieldSpec):
+            f.append(field)
+        else:
+            raise ValueError(f"Cannot parse field: {field}")
+    return f
+
+
+def _standardize_sort(sort: str | list[str] | list[dict[str, SortSpec]] | None = None) -> list[dict[str, SortSpec]] | None:
+    """Convert sort to list of dicts."""
+
+    ## TODO: sort cannot be right. that array around dict is useless
+
+    if not sort:
+        return None
+    if isinstance(sort, str):
+        return [{sort: SortSpec(order="asc")}]
+
+    sortspec: list[dict[str, SortSpec]] = []
+
+    for field in sort:
+        if isinstance(field, str):
+            sortspec.append({field: SortSpec(order="asc")})
+        elif isinstance(field, dict):
+            sortspec.append(field)
+        else:
+            raise ValueError(f"Cannot parse sort: {sort}")
+
+    return sortspec
 
 
 @app_query.post("/{index}/query", response_model=QueryResult)
 def query_documents_post(
-    index: str,
+    index: str | list[str],
+    queries: Annotated[
+        str | list[str] | dict[str, str] | None,
+        Body(
+            description="Query/Queries to run. Value should be a single query string, a list of query strings, "
+            "or a dict of {'label': 'query'}",
+        ),
+    ] = None,
+    fields: Annotated[
+        list[str | FieldSpec] | None,
+        Body(
+            description="List of fields to retrieve for each document. "
+            "In the list you can specify a fieldname, but also a FieldSpec dict. "
+            "Using a FieldSpec allows you to request only a snippet of a field. "
+            "'Matches' here refers to words from text queries: if there is no query, the snippet is the first "
+            "[nomatch_chars] characters; if there is a query, snippets are returned for up to [max_matches] matches, "
+            "with each match having [match_chars] characters. Multiple matches are concatenated with ' ... '.",
+            openapi_examples={
+                "simple": {"summary": "Retrieve single field", "value": '["title", "text", "date"]'},
+                "text as snippet": {
+                    "summary": "Retrieve the full title, but text only as snippet",
+                    "value": '["title", {"name": "text", "snippet": {"nomatch_chars": 100}}]',
+                },
+                "all allowed fields": {
+                    "summary": "If fields is left empty, all fields that the user is allowed to see are returned",
+                },
+            },
+        ),
+    ] = None,
+    filters: Annotated[
+        dict[str, FilterValue | list[FilterValue] | FilterSpec] | None,
+        Body(
+            description="Field filters, should be a dict of field names to filter specifications, "
+            "which can be either a value, a list of values, or a FilterSpec dict",
+        ),
+    ] = None,
+    sort: Annotated[
+        str | list[str] | list[dict[str, SortSpec]] | None,
+        Body(
+            description="Sort by field name(s) or dict (see "
+            "https://www.elastic.co/guide/en/elasticsearch/reference/current/sort-search-results.html for dict format)",
+            openapi_examples={
+                "simple": {"summary": "Sort by single field", "value": "'date'"},
+                "multiple": {
+                    "summary": "Sort by multiple fields",
+                    "value": "['date', 'title']",
+                },
+                "dict": {
+                    "summary": "Use dict to specify sort options",
+                    "value": " [{'date': {'order':'desc'}}]",
+                },
            },
-    ),
-    per_page: Optional[int] = Body(10, description="Number of documents per page"),
-    page: Optional[int] = Body(0, description="Which page to retrieve"),
-    scroll: Optional[str] = Body(
-        None,
-        description="Scroll specification (e.g. '5m') to start a scroll request"
-        "This will return a scroll_id which should be passed to subsequent calls"
-        "(this is the advised way of scrolling through multiple pages of results)",
-        examples="5m",
-    ),
-    scroll_id: Optional[str] = Body(
-        None, description="Scroll id from previous response to continue scrolling"
-    ),
-    highlight: Optional[bool] = Body(False, description="If true, highlight fields"),
-    user=Depends(authenticated_user),
+        ),
+    ] = None,
+    per_page: Annotated[int, Body(description="Number of documents per page")] = 10,
+    page: Annotated[int, Body(description="Which page to retrieve")] = 0,
+    scroll: Annotated[
+        str | None,
+        Body(
+            description="Scroll specification (e.g. '5m') to start a scroll request. "
+            "This will return a scroll_id which should be passed to subsequent calls "
+            "(this is the advised way of scrolling through multiple pages of results)",
+            examples=["5m"],
+        ),
+    ] = None,
+    scroll_id: Annotated[str | None, Body(description="Scroll id from previous response to continue scrolling")] = None,
+    highlight: Annotated[bool, Body(description="If true, highlight fields")] = False,
+    user: str = Depends(authenticated_user),
 ):
     """
     List or query documents in this index. 
Returns a JSON object {data: [...], meta: {total_count, per_page, page_count, page|scroll_id}} """ - # TODO check user rights on index - # Standardize fields, queries and filters to their most versatile format - indices = index.split(",") - fields = [fields] if isinstance(fields, str) else fields - fields = get_or_validate_allowed_fields(user, indices, fields) - queries = _process_queries(queries) - filters = dict(_process_filters(filters)) + indices = index if isinstance(index, list) else [index] + fieldspecs = get_or_validate_allowed_fields(user, indices, _standardize_fields(fields)) + r = query.query_documents( indices, - queries=queries, - filters=filters, - fields=fields, - sort=sort, + queries=_standardize_queries(queries), + filters=_standardize_filters(filters), + fields=fieldspecs, + sort=_standardize_sort(sort), per_page=per_page, page=page, scroll_id=scroll_id, @@ -323,25 +273,21 @@ class AxisSpec(BaseModel): @app_query.post("/{index}/aggregate") def query_aggregate_post( index: str, - axes: Optional[List[AxisSpec]] = Body( - None, description="Axes to aggregate on (i.e. group by)" - ), - aggregations: Optional[List[AggregationSpec]] = Body( - None, description="Aggregate functions to compute" - ), + axes: Optional[List[AxisSpec]] = Body(None, description="Axes to aggregate on (i.e. group by)"), + aggregations: Optional[List[AggregationSpec]] = Body(None, description="Aggregate functions to compute"), queries: Optional[Union[str, List[str], Dict[str, str]]] = Body( None, description="Query/Queries to run. Value should be a single query string, a list of query strings, " "or a dict of queries {'label': 'query'}", ), - filters: Optional[ - Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] - ] = Body( - None, - description="Field filters, should be a dict of field names to filter specifications," - "which can be either a value, a list of values, or a FilterSpec dict", - ), - _user=Depends(authenticated_user), + filters: Annotated[ + dict[str, FilterValue | list[FilterValue] | FilterSpec] | None, + Body( + description="Field filters, should be a dict of field names to filter specifications," + "which can be either a value, a list of values, or a FilterSpec dict", + ), + ] = None, + _user: str = Depends(authenticated_user), ): """ Construct an aggregate query. 
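
A hedged sketch of a request this endpoint accepts; the app import path, token, index and field names are assumptions (AxisSpec is taken to carry field and interval, matching Axis):

    from fastapi.testclient import TestClient
    from amcat4 import api  # assuming the FastAPI app is exposed as api.app

    client = TestClient(api.app)
    resp = client.post(
        "/index/my_index/aggregate",
        headers={"Authorization": "Bearer ..."},  # placeholder token
        json={
            "axes": [{"field": "date", "interval": "year"}],
            "aggregations": [{"field": "i", "function": "avg"}],
            "queries": {"climate": "climate*"},
            "filters": {"cat": {"values": ["a"]}},
        },
    )
    print(resp.json())  # contains the meta block and the aggregated rows
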
@@ -357,18 +303,15 @@ def query_aggregate_post( # TODO check user rights on index indices = index.split(",") _axes = [Axis(**x.model_dump()) for x in axes] if axes else [] - _aggregations = ( - [Aggregation(**x.model_dump()) for x in aggregations] if aggregations else [] - ) + _aggregations = [Aggregation(**x.model_dump()) for x in aggregations] if aggregations else [] if not (_axes or _aggregations): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Aggregation needs at least one axis or aggregation", ) - queries = _process_queries(queries) - filters = dict(_process_filters(filters)) + results = aggregate.query_aggregate( - indices, _axes, _aggregations, queries=queries, filters=filters + indices, _axes, _aggregations, queries=_standardize_queries(queries), filters=_standardize_filters(filters) ) return { "meta": { @@ -386,9 +329,7 @@ def query_aggregate_post( ) def query_update_tags( index: str, - action: Literal["add", "remove"] = Body( - None, description="Action (add or remove) on tags" - ), + action: Literal["add", "remove"] = Body(None, description="Action (add or remove) on tags"), field: str = Body(None, description="Tag field to update"), tag: str = Body(None, description="Tag to add or remove"), queries: Optional[Union[str, List[str], Dict[str, str]]] = Body( @@ -396,24 +337,20 @@ def query_update_tags( description="Query/Queries to run. Value should be a single query string, a list of query strings, " "or a dict of {'label': 'query'}", ), - filters: Optional[ - Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]] - ] = Body( + filters: Optional[Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]]] = Body( None, description="Field filters, should be a dict of field names to filter specifications," "which can be either a value, a list of values, or a FilterSpec dict", ), - ids: Optional[Union[str, List[str]]] = Body( - None, description="Document IDs of documents to update" - ), - _user=Depends(authenticated_user), + ids: Optional[Union[str, List[str]]] = Body(None, description="Document IDs of documents to update"), + _user: str = Depends(authenticated_user), ): """ Add or remove tags by query or by id """ indices = index.split(",") - queries = _process_queries(queries) - filters = dict(_process_filters(filters)) + queries = _standardize_queries(queries) + filters = _standardize_filters(filters) if isinstance(ids, (str, int)): ids = [ids] update_tag_query(indices, action, field, tag, queries, filters, ids) diff --git a/amcat4/api/users.py b/amcat4/api/users.py index 4eda374..e88b9aa 100644 --- a/amcat4/api/users.py +++ b/amcat4/api/users.py @@ -83,15 +83,10 @@ def _get_user(email, current_user): @app_users.get("/users", dependencies=[Depends(authenticated_admin)]) def list_global_users(): """List all global users""" - return [ - {"email": email, "role": role.name} - for (email, role) in index.list_global_users().items() - ] + return [{"email": email, "role": role.name} for (email, role) in index.list_global_users().items()] -@app_users.delete( - "/users/{email}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response -) +@app_users.delete("/users/{email}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) def delete_user(email: EmailStr, current_user: str = Depends(authenticated_user)): """ Delete the given user. 
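
A minimal sketch of exercising this endpoint; host, email and token are placeholders:

    import requests

    resp = requests.delete(
        "http://localhost:5000/users/someone@example.com",
        headers={"Authorization": "Bearer ..."},  # admin token, or the user's own
    )
    assert resp.status_code == 204  # declared above as HTTP_204_NO_CONTENT
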
@@ -104,9 +99,7 @@ def delete_user(email: EmailStr, current_user: str = Depends(authenticated_user) @app_users.put("/users/{email}") -def modify_user( - email: EmailStr, data: ChangeUserForm, _user=Depends(authenticated_admin) -): +def modify_user(email: EmailStr, data: ChangeUserForm, _user: str = Depends(authenticated_admin)): """ Modify the given user. Only admin can change users. diff --git a/amcat4/date_mappings.py b/amcat4/date_mappings.py index ffbfc40..0dc77c6 100644 --- a/amcat4/date_mappings.py +++ b/amcat4/date_mappings.py @@ -6,10 +6,7 @@ class DateMapping: interval = None def mapping(self, field: str) -> dict: - return {self.fieldname(field): { - "type": self.mapping_type(), - "script": self.mapping_script(field) - }} + return {self.fieldname(field): {"type": self.mapping_type(), "script": self.mapping_script(field)}} def mapping_script(self, field: str) -> str: raise NotImplementedError() @@ -96,10 +93,12 @@ def postprocess(self, value): return int(value) -def interval_mapping(interval: str) -> Optional[DateMapping]: - for m in mappings(): - if m.interval == interval: - return m +def interval_mapping(interval: str | None) -> Optional[DateMapping]: + if interval is not None: + for m in mappings(): + if m.interval == interval: + return m + return None def mappings() -> Iterable[DateMapping]: diff --git a/amcat4/elastic.py b/amcat4/elastic.py index 8ee6908..bb16ffa 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -6,8 +6,7 @@ - The elasticsearch backend should contain a system index, which will be created if needed - The system index contains a 'document' for each used index containing: {auth: [{email: role}], guest_role: role} -- We define the mappings (field types) based on existing elasticsearch mappings, - but use field metadata to define specific fields. + """ import functools diff --git a/amcat4/index.py b/amcat4/index.py index 45c06f3..08e4f1a 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -28,6 +28,8 @@ - This system index contains a 'document' for each index: {name: "...", description:"...", guest_role: "...", roles: [{email, role}...]} - A special _global document defines the global properties for this instance (name, roles) +- We define the mappings (field types) based on existing elasticsearch mappings, + but use field metadata to define specific fields. """ import collections from enum import IntEnum @@ -75,14 +77,15 @@ } DEFAULT_INDEX_FIELDS = { - "text": Field(type="text", metareader_access=DEFAULT_METAREADER["text"]), - "title": Field(type="text", metareader_access=DEFAULT_METAREADER["text"]), - "date": Field(type="date", metareader_access=DEFAULT_METAREADER["date"]), - "url": Field(type="url", metareader_access=DEFAULT_METAREADER["url"]), + "text": Field(type="text", metareader=DEFAULT_METAREADER["text"]), + "title": Field(type="text", metareader=DEFAULT_METAREADER["text"]), + "date": Field(type="date", metareader=DEFAULT_METAREADER["date"]), + "url": Field(type="url", metareader=DEFAULT_METAREADER["url"]), } class Role(IntEnum): + NONE = 0 METAREADER = 10 READER = 20 WRITER = 30 @@ -294,7 +297,10 @@ def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField] """ Set the fields settings for this index. - Note that if UpdateField also updates the field type, we need to update the mapping as well. + Note that we're storing fields in two places. We keep all field settings in the system index. + But the index that contains the documents also needs to know what fields there are and + what their (elastic) types are. 
So whenever fields are added or the type is updated, we
+    also update the index mapping.
     """
     system_index = get_settings().system_index
     try:
@@ -306,19 +312,24 @@ def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField]
     fields = _fields_from_elastic(d["_source"].get("fields", {}))
 
     for field, new_settings in new_fields.items():
-        if new_settings.type is not None:
-            # if new type is specified, we need to update the index mapping properties
-            type_mappings[field] = ES_MAPPINGS[new_settings.type]
-
         current = fields.get(field)
         if current is None:
             # Create field
             if new_settings.type is None:
                 raise ValueError(f"Field {field} does not yet exist, and to create a new field you need to specify a type")
+            type_mappings[field] = ES_MAPPINGS[new_settings.type]
             fields[field] = Field(**new_settings.model_dump())
         else:
             # Update field
+            # it is not possible to update elastic field types, but we can change amcat types (see ES_MAPPINGS)
+            if new_settings.type is not None:
+                if ES_MAPPINGS[current.type] != ES_MAPPINGS[new_settings.type]:
+                    raise ValueError(
+                        f"Field {field} already exists with type {current.type}, cannot change to {new_settings.type}"
+                    )
+
+            # set new field settings (amcat type, metareader, etc.)
             fields[field] = updateField(current, new_settings)
 
     es().indices.put_mapping(index=index, properties=type_mappings)
@@ -370,7 +381,7 @@ def remove_global_role(email: str):
     remove_role(index=GLOBAL_ROLES, email=email)
 
 
-def get_role(index: str, email: str) -> Optional[Role]:
+def get_role(index: str, email: str) -> Role:
     """
     Retrieve the role of this user on this index, or the guest role if user has no role
     Raises a ValueError if the index does not exist
@@ -388,7 +399,7 @@ def get_role(index: str, email: str) -> Role:
     if role := roles_dict.get(email):
         return role
     if index == GLOBAL_ROLES:
-        return None
+        return Role.NONE
 
     guest_role: str | None = doc["_source"].get("guest_role", None)
     if guest_role and guest_role.lower() != "none":
@@ -580,7 +591,6 @@ def get_field_values(index: str, field: str, size: int) -> List[str]:
 def get_field_stats(index: str, field: str) -> List[str]:
     """
-    Get field statistics, such as min, max, avg, etc.
     :param index: The index
     :param field: The field name
     :return: A list of values
@@ -590,9 +600,9 @@ def get_field_stats(index: str, field: str) -> List[str]:
     return r["aggregations"]["facets"]
 
 
-def update_by_query(index: str, script: str, query: dict, params: dict = None):
-    script = dict(source=script, lang="painless", params=params or {})
-    es().update_by_query(index=index, script=script, **query)
+def update_by_query(index: str, script: str, query: dict, params: dict | None = None):
+    script_dict = dict(source=script, lang="painless", params=params or {})
+    es().update_by_query(index=index, script=script_dict, **query)
 
 
 TAG_SCRIPTS = dict(
diff --git a/amcat4/models.py b/amcat4/models.py
index 2213c28..bfddd36 100644
--- a/amcat4/models.py
+++ b/amcat4/models.py
@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import Literal
+from typing import Literal, NewType, Union
 
 
 class SnippetParams(BaseModel):
@@ -9,9 +9,9 @@ class SnippetParams(BaseModel):
     the first [nomatch_chars] of the field. 
""" - nomatch_chars: int - max_matches: int - match_chars: int + nomatch_chars: int = 0 + max_matches: int = 0 + match_chars: int = 0 class FieldMetareaderAccess(BaseModel): @@ -25,17 +25,44 @@ class Field(BaseModel): """Settings for a field.""" type: str - metareader_access: FieldMetareaderAccess + metareader: FieldMetareaderAccess class UpdateField(BaseModel): """Model for updating a field""" type: str | None = None - metareader_access: FieldMetareaderAccess | None = None + metareader: FieldMetareaderAccess | None = None def updateField(field: Field, update: UpdateField | Field): for key in field.model_fields_set: setattr(field, key, getattr(update, key)) return field + + +FilterValue = str | int + + +class FilterSpec(BaseModel): + """Form for filter specification.""" + + values: list[FilterValue] | None = None + gt: FilterValue | None = None + lt: FilterValue | None = None + gte: FilterValue | None = None + lte: FilterValue | None = None + exists: bool | None = None + + +class FieldSpec(BaseModel): + """Form for field specification.""" + + name: str + snippet: SnippetParams | None = None + + +class SortSpec(BaseModel): + """Form for sort specification.""" + + order: Literal["asc", "desc"] = "asc" diff --git a/amcat4/query.py b/amcat4/query.py index 71e8e33..9fa322c 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -16,20 +16,20 @@ Literal, ) +from amcat4.models import FieldSpec, FilterSpec, SortSpec + from .date_mappings import mappings -from .elastic import es, update_tag_by_query -from amcat4 import elastic -from amcat4.util import parse_field -from amcat4.index import Role, get_role +from .elastic import es +from amcat4.index import update_tag_by_query def build_body( - queries: Iterable[str] = None, - filters: Mapping = None, - highlight: dict = None, - ids: Iterable[str] = None, + queries: dict[str, str] | None = None, + filters: dict[str, FilterSpec] | None = None, + highlight: dict | None = None, + ids: list[str] | None = None, ): - def parse_filter(field, filter) -> Tuple[Mapping, Mapping]: + def parse_filter(field, filter) -> Tuple[dict, dict]: filter = filter.copy() extra_runtime_mappings = {} field_filters = [] @@ -41,9 +41,7 @@ def parse_filter(field, filter) -> Tuple[Mapping, Mapping]: if filter.pop("exists"): field_filters.append({"exists": {"field": field}}) else: - field_filters.append( - {"bool": {"must_not": {"exists": {"field": field}}}} - ) + field_filters.append({"bool": {"must_not": {"exists": {"field": field}}}}) for mapping in mappings(): if mapping.interval in filter: value = filter.pop(mapping.interval) @@ -62,7 +60,8 @@ def parse_filter(field, filter) -> Tuple[Mapping, Mapping]: def parse_query(q: str) -> dict: return {"query_string": {"query": q}} - def parse_queries(qs: Sequence[str]) -> dict: + def parse_queries(queries: dict[str, str]) -> dict: + qs = queries.values() if len(qs) == 1: return parse_query(list(qs)[0]) else: @@ -78,10 +77,8 @@ def parse_queries(qs: Sequence[str]) -> dict: fs.append(filter_term) if extra_runtime_mappings: runtime_mappings.update(extra_runtime_mappings) - if queries: - if isinstance(queries, dict): - queries = queries.values() - fs.append(parse_queries(list(queries))) + if queries is not None: + fs.append(parse_queries(queries)) if ids: fs.append({"ids": {"values": list(ids)}}) body: Dict[str, Any] = {"query": {"bool": {"filter": fs}}} @@ -97,12 +94,12 @@ def parse_queries(qs: Sequence[str]) -> dict: class QueryResult: def __init__( self, - data: List[dict], - n: int = None, - per_page: int = None, - page: int = None, - 
page_count: int = None,
-        scroll_id: str = None,
+        data: list[dict],
+        n: int | None = None,
+        per_page: int | None = None,
+        page: int | None = None,
+        page_count: int | None = None,
+        scroll_id: str | None = None,
     ):
         if n and (page_count is None) and (per_page is not None):
             page_count = ceil(n / per_page)
@@ -113,8 +110,8 @@ def __init__(
         self.per_page = per_page
         self.scroll_id = scroll_id
 
-    def as_dict(self):
-        meta = {
+    def as_dict(self) -> dict:
+        meta: dict[str, int | str | None] = {
             "total_count": self.total_count,
             "per_page": self.per_page,
             "page_count": self.page_count,
@@ -126,31 +123,20 @@ def as_dict(self):
         return dict(meta=meta, results=self.data)
 
 
-def _normalize_queries(
-    queries: Optional[Union[Dict[str, str], Iterable[str]]]
-) -> Mapping[str, str]:
-    if queries is None:
-        return {}
-    if isinstance(queries, dict):
-        return queries
-    return {q: q for q in queries}
-
-
 def query_documents(
     index: Union[str, Sequence[str]],
-    queries: Union[Mapping[str, str], Iterable[str]] = None,
+    fields: list[FieldSpec],
+    queries: dict[str, str] | None = None,
+    filters: dict[str, FilterSpec] | None = None,
+    sort: list[dict[str, SortSpec]] | None = None,
     *,
     page: int = 0,
     per_page: int = 10,
     scroll=None,
-    scroll_id: str = None,
-    fields: Iterable[str] = None,
-    snippets: Iterable[str] = None,
-    filters: Mapping[str, Mapping] = None,
-    highlight: Literal["none", "text", "snippets"] = "none",
-    sort: List[Union[str, Mapping]] = None,
+    scroll_id: str | None = None,
+    highlight: bool = False,
     **kwargs,
-) -> Optional[QueryResult]:
+) -> QueryResult | None:
     """
     Conduct a query_string query, returning the found documents.
 
     If the scroll parameter is given, the result will contain a scroll_id which can be used to get the next batch.
     In case there are no more documents to scroll, it will return None
     :param index: The name of the index or indexes
-    :param queries: a list of queries OR a dict {label1: query1, ...}
+    :param fields: List of fields using the FieldSpec syntax. We enforce specific field selection here. Any logic
+        for determining whether a user can see the field should be done in the API layer.
+    :param queries: if not None, a dict with labels and queries {label1: query1, ...}
+    :param filters: if not None, a dict where the key is the field and the value is a FilterSpec
+
+    :param page: The number of the page to request (starting from zero)
     :param per_page: The number of hits per page
     :param scroll: if not None, will create a scroll request rather than a paginated request. Parameter should specify
         the time the context should be kept alive, or True to get the default of 2m.
     :param scroll_id: if not None, should be a previously returned context_id to retrieve a new page of results
-    :param fields: if not None, specify a list of fields to retrieve for each hit
-    :param filters: if not None, a dict of filters with either value, values, or gte/gt/lte/lt ranges:
-                    {field: {'values': [value1,value2],
-                             'value': value,
-                             'gte/gt/lte/lt': value,
-                             ...}}
     :param highlight: if True, add <em> tags to query matches in fields
     :param sort: Sort order of results, can be either a single field or a list of fields. In the list, each field
         is a string or a dict with options, e.g. 
["id", {"date": {"order": "desc"}}] @@ -184,7 +168,7 @@ def query_documents( if scroll or scroll_id: # set scroll to default also if scroll_id is given but no scroll time is known kwargs["scroll"] = "2m" if (not scroll or scroll is True) else scroll - queries = _normalize_queries(queries) + if sort is not None: kwargs["sort"] = sort if scroll_id: @@ -193,7 +177,7 @@ def query_documents( return None else: h = query_highlight(fields, highlight) - body = build_body(queries.values(), filters, h) + body = build_body(queries, filters, h) if fields: fields = fields if isinstance(fields, list) else list(fields) @@ -212,9 +196,7 @@ def query_documents( hitdict[key] = " ... ".join(hit["highlight"][key]) data.append(hitdict) if scroll_id: - return QueryResult( - data, n=result["hits"]["total"]["value"], scroll_id=result["_scroll_id"] - ) + return QueryResult(data, n=result["hits"]["total"]["value"], scroll_id=result["_scroll_id"]) elif scroll: return QueryResult( data, @@ -223,49 +205,41 @@ def query_documents( scroll_id=result["_scroll_id"], ) else: - return QueryResult( - data, n=result["hits"]["total"]["value"], per_page=per_page, page=page - ) + return QueryResult(data, n=result["hits"]["total"]["value"], per_page=per_page, page=page) -def query_highlight(fields: Iterable[str] = None, highlight_queries: bool = False): +def query_highlight(fields: list[FieldSpec], highlight_queries: bool = False) -> dict[str, Any]: """ The elastic "highlight" parameters works for both highlighting text fields and adding snippets. This function will return the highlight parameter to be added to the query body. """ - highlight = { + highlight: dict[str, Any] = { # "pre_tags": [""] if highlight is True else [""], # "post_tags": [""] if highlight is True else [""], "require_field_match": True, + "fields": {}, } - if fields is None: - if highlight_queries is True: - highlight["fields"]["*"] = {"number_of_fragments": 0} - else: - highlight["fields"] = {} - for field in fields: - fieldname, nomatch_chars, max_matches, match_chars = parse_field(field) - if nomatch_chars is None: - if highlight_queries is True: - # This will overwrite the field with the highlighted version, so - # only needed if highlight is True - highlight["fields"][fieldname] = {"number_of_fragments": 0} - else: - # the elastic highlight feature is also used to get snippets. note that - # above in the - highlight["fields"][fieldname] = { - "no_match_size": nomatch_chars, - "number_of_fragments": max_matches, - "fragment_size": match_chars, - } - if highlight_queries is False or max_matches == 0: - # This overwrites the actual query, so that the highlights are not returned. - # Also used to get the nomatch snippet if max_matches = 0 - highlight["fields"][fieldname]["highlight_query"] = { - "match_all": {} - } + for field in fields: + if field.snippet is None: + if highlight_queries is True: + # This will overwrite the field with the highlighted version, so + # only needed if highlight is True + highlight["fields"][field.name] = {"number_of_fragments": 0} + else: + # the elastic highlight feature is also used to get snippets. note that + # above in the + highlight["fields"][field.name] = { + "no_match_size": field.snippet.nomatch_chars, + "number_of_fragments": field.snippet.max_matches, + "fragment_size": field.snippet.match_chars, + } + if highlight_queries is False or field.snippet.max_matches == 0: + # This overwrites the actual query, so that the highlights are not returned. 
+ # Also, if max_matches is zero, we drop the query for highlighting so that + # the nomatch_chars are returned + highlight["fields"][field.name]["highlight_query"] = {"match_all": {}} return highlight @@ -284,13 +258,13 @@ def overwrite_highlight_results(hit: dict, hitdict: dict): def update_tag_query( - index: Union[str, Sequence[str]], + index: str | list[str], action: Literal["add", "remove"], field: str, tag: str, - queries: Union[Mapping[str, str], Iterable[str]] = None, - filters: Mapping[str, Mapping] = None, - ids: Sequence[str] = None, + queries: dict[str, str] | list[str] | None = None, + filters: dict[str, dict] | None = None, + ids: list[str] | None = None, ): """Add or remove tags using a query""" body = build_body(queries and queries.values(), filters, ids=ids) diff --git a/amcat4/util.py b/amcat4/util.py deleted file mode 100644 index fbefe23..0000000 --- a/amcat4/util.py +++ /dev/null @@ -1,25 +0,0 @@ -import re -from typing import Tuple - - -def parse_field(field: str) -> Tuple[str, int, int, int]: - """ - Parse a field into a field and the snippet parameters. - The format is fieldname[nomatch_chars;max_matches;match_chars]. - If no snippet parameters are given, the values are None - """ - pattern = r"\[([0-9]+);([0-9]+);([0-9]+)]$" - match = re.search(pattern, field) - - if match: - fieldname = field[: match.start()] - nomatch_chars = int(match.group(1)) - max_matches = int(match.group(2)) - match_chars = int(match.group(3)) - else: - fieldname = field - nomatch_chars = None - max_matches = None - match_chars = None - - return fieldname, nomatch_chars, max_matches, match_chars diff --git a/tests/conftest.py b/tests/conftest.py index bd651af..d3e2a43 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import responses from fastapi.testclient import TestClient -from amcat4 import elastic, api # noqa: E402 +from amcat4 import api # noqa: E402 from amcat4.config import get_settings, AuthOptions from amcat4.elastic import es from amcat4.index import ( @@ -15,6 +15,7 @@ delete_user, remove_global_role, set_global_role, + upload_documents, ) from tests.middlecat_keypair import PUBLIC_KEY @@ -33,9 +34,7 @@ def mock_middlecat(): get_settings().middlecat_url = "http://localhost:5000" get_settings().host = "http://localhost:3000" with responses.RequestsMock(assert_all_requests_are_fired=False) as resp: - resp.get( - "http://localhost:5000/api/configuration", json={"public_key": PUBLIC_KEY} - ) + resp.get("http://localhost:5000/api/configuration", json={"public_key": PUBLIC_KEY}) yield None @@ -149,7 +148,7 @@ def upload(index: str, docs: Iterable[dict], **kwargs): for k, v in defaults.items(): if k not in doc: doc[k] = v - elastic.upload_documents(index, docs, **kwargs) + upload_documents(index, docs, **kwargs) refresh_index(index) return ids @@ -214,10 +213,7 @@ def index_many(): create_index(index, guest_role=Role.READER) upload( index, - [ - dict(id=i, pagenr=abs(10 - i), text=text) - for (i, text) in enumerate(["odd", "even"] * 10) - ], + [dict(id=i, pagenr=abs(10 - i), text=text) for (i, text) in enumerate(["odd", "even"] * 10)], ) yield index delete_index(index, ignore_missing=True) diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 8d94590..739137e 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -6,13 +6,15 @@ from tests.tools import dictset -def do_query(index: str, *axes, **kargs): +def do_query(index: str, *args, **kwargs): def _key(x): if len(x) == 1: return x[0] return x - axes = [Axis(x) if isinstance(x, str) 
else x for x in axes] - result = query_aggregate(index, axes, **kargs) + + axes = [Axis(x) if isinstance(x, str) else x for x in args] + + result = query_aggregate(index, axes, **kwargs) return {_key(vals[:-1]): vals[-1] for vals in result.data} @@ -27,29 +29,27 @@ def _y(y): def test_aggregate(index_docs): q = functools.partial(do_query, index_docs) assert q(Axis("cat")) == {"a": 3, "b": 1} - assert q(Axis(field="date")) == {_d('2018-01-01'): 2, _d('2018-02-01'): 1, _d('2020-01-01'): 1} + assert q(Axis(field="date")) == {_d("2018-01-01"): 2, _d("2018-02-01"): 1, _d("2020-01-01"): 1} def test_aggregate_querystring(index_docs): q = functools.partial(do_query, index_docs) - assert q("cat", queries=['toto']) == {"a": 1, "b": 1} - assert q("cat", queries=['test*']) == {"a": 2, "b": 1} - assert q("cat", queries=['"a text"', 'another']) == {"a": 2} + assert q("cat", queries=["toto"]) == {"a": 1, "b": 1} + assert q("cat", queries=["test*"]) == {"a": 2, "b": 1} + assert q("cat", queries=['"a text"', "another"]) == {"a": 2} def test_interval(index_docs): q = functools.partial(do_query, index_docs) assert q(Axis(field="date", interval="year")) == {_y(2018): 3, _y(2020): 1} - assert q(Axis(field="i", interval="10")) == {0.: 2, 10.: 1, 30.: 1} + assert q(Axis(field="i", interval="10")) == {0.0: 2, 10.0: 1, 30.0: 1} def test_second_axis(index_docs): q = functools.partial(do_query, index_docs) - assert q("cat", 'subcat') == {("a", "x"): 2, ("a", "y"): 1, ("b", "y"): 1} - assert (q(Axis(field="date", interval="year"), 'cat') - == {(_y(2018), "a"): 2, (_y(2020), "a"): 1, (_y(2018), "b"): 1}) - assert (q('cat', Axis(field="date", interval="year")) - == {("a", _y(2018)): 2, ("a", _y(2020)): 1, ("b", _y(2018)): 1}) + assert q("cat", "subcat") == {("a", "x"): 2, ("a", "y"): 1, ("b", "y"): 1} + assert q(Axis(field="date", interval="year"), "cat") == {(_y(2018), "a"): 2, (_y(2020), "a"): 1, (_y(2018), "b"): 1} + assert q("cat", Axis(field="date", interval="year")) == {("a", _y(2018)): 2, ("a", _y(2020)): 1, ("b", _y(2018)): 1} def test_count(index_docs): @@ -61,46 +61,65 @@ def test_count(index_docs): def test_byquery(index_docs): """Get number of documents per query""" assert do_query(index_docs, Axis("_query"), queries=["text", "test*"]) == {"text": 2, "test*": 3} - assert (do_query(index_docs, Axis("_query"), Axis("subcat"), queries=["text", "test*"]) == - {("text", "x"): 2, ("test*", "x"): 1, ("test*", "y"): 2}) - assert (do_query(index_docs, Axis("subcat"), Axis("_query"), queries=["text", "test*"]) == - {("x", "text"): 2, ("x", "test*"): 1, ("y", "test*"): 2}) + assert do_query(index_docs, Axis("_query"), Axis("subcat"), queries=["text", "test*"]) == { + ("text", "x"): 2, + ("test*", "x"): 1, + ("test*", "y"): 2, + } + assert do_query(index_docs, Axis("subcat"), Axis("_query"), queries=["text", "test*"]) == { + ("x", "text"): 2, + ("x", "test*"): 1, + ("y", "test*"): 2, + } def test_metric(index_docs: str): """Do metric aggregations (e.g. 
avg(x)) work?""" + # Single and double aggregation with axis def q(axes, aggregations): return dictset(query_aggregate(index_docs, axes, aggregations).as_dicts()) - assert (q([Axis("subcat")], [Aggregation("i", "avg")]) == - dictset([{"subcat": "x", "n": 2, "avg_i": 1.5}, {"subcat": "y", "n": 2, "avg_i": 21.0}])) - assert (q([Axis("subcat")], [Aggregation("i", "avg"), Aggregation("i", "max")]) == - dictset([{"subcat": "x", "n": 2, "avg_i": 1.5, "max_i": 2.0}, - {"subcat": "y", "n": 2, "avg_i": 21.0, "max_i": 31.0}])) + + assert q([Axis("subcat")], [Aggregation("i", "avg")]) == dictset( + [{"subcat": "x", "n": 2, "avg_i": 1.5}, {"subcat": "y", "n": 2, "avg_i": 21.0}] + ) + assert q([Axis("subcat")], [Aggregation("i", "avg"), Aggregation("i", "max")]) == dictset( + [{"subcat": "x", "n": 2, "avg_i": 1.5, "max_i": 2.0}, {"subcat": "y", "n": 2, "avg_i": 21.0, "max_i": 31.0}] + ) # Aggregation only - assert (q(None, [Aggregation("i", "avg")]) == dictset([{"n": 4, "avg_i": 11.25}])) - assert (q(None, [Aggregation("i", "avg"), Aggregation("i", "max")]) == - dictset([{"n": 4, "avg_i": 11.25, "max_i": 31.0}])) + assert q(None, [Aggregation("i", "avg")]) == dictset([{"n": 4, "avg_i": 11.25}]) + assert q(None, [Aggregation("i", "avg"), Aggregation("i", "max")]) == dictset([{"n": 4, "avg_i": 11.25, "max_i": 31.0}]) # Check value handling - Aggregation on date fields - assert (q(None, [Aggregation("date", "max")]) == dictset([{"n": 4, "max_date": "2020-01-01T00:00:00"}])) - assert (q([Axis("subcat")], [Aggregation("date", "avg")]) == - dictset([{"subcat": "x", "n": 2, "avg_date": "2018-01-16T12:00:00"}, - {"subcat": "y", "n": 2, "avg_date": "2019-01-01T00:00:00"}])) + assert q(None, [Aggregation("date", "max")]) == dictset([{"n": 4, "max_date": "2020-01-01T00:00:00"}]) + assert q([Axis("subcat")], [Aggregation("date", "avg")]) == dictset( + [ + {"subcat": "x", "n": 2, "avg_date": "2018-01-16T12:00:00"}, + {"subcat": "y", "n": 2, "avg_date": "2019-01-01T00:00:00"}, + ] + ) def test_aggregate_datefunctions(index: str): q = functools.partial(do_query, index) - docs = [dict(date=x) for x in ["2018-01-01T04:00:00", # monday night - "2018-01-01T09:00:00", # monday morning - "2018-01-11T09:00:00", # thursday morning - "2018-01-17T11:00:00", # wednesday morning - "2018-01-17T18:00:00", # wednesday evening - "2018-03-07T23:59:00", # wednesday evening - ]] + docs = [ + dict(date=x) + for x in [ + "2018-01-01T04:00:00", # monday night + "2018-01-01T09:00:00", # monday morning + "2018-01-11T09:00:00", # thursday morning + "2018-01-17T11:00:00", # wednesday morning + "2018-01-17T18:00:00", # wednesday evening + "2018-03-07T23:59:00", # wednesday evening + ] + ] upload(index, docs) - assert q(Axis("date", interval="day")) == {date(2018, 1, 1): 2, date(2018, 1, 11): 1, - date(2018, 1, 17): 2, date(2018, 3, 7): 1} + assert q(Axis("date", interval="day")) == { + date(2018, 1, 1): 2, + date(2018, 1, 11): 1, + date(2018, 1, 17): 2, + date(2018, 3, 7): 1, + } assert q(Axis("date", interval="dayofweek")) == {"Monday": 2, "Wednesday": 3, "Thursday": 1} assert q(Axis("date", interval="daypart")) == {"Night": 1, "Morning": 3, "Evening": 2} assert q(Axis("date", interval="monthnr")) == {1: 5, 3: 1} @@ -108,4 +127,8 @@ def test_aggregate_datefunctions(index: str): assert q(Axis("date", interval="dayofmonth")) == {1: 2, 11: 1, 17: 2, 7: 1} assert q(Axis("date", interval="weeknr")) == {1: 2, 2: 1, 3: 2, 10: 1} assert q(Axis("date", interval="month"), Axis("date", interval="dayofmonth")) == { - (date(2018, 1, 1), 1): 2, (date(2018, 
1, 1), 11): 1, (date(2018, 1, 1), 17): 2, (date(2018, 3, 1), 7): 1} + (date(2018, 1, 1), 1): 2, + (date(2018, 1, 1), 11): 1, + (date(2018, 1, 1), 17): 2, + (date(2018, 3, 1), 7): 1, + } diff --git a/tests/test_query.py b/tests/test_query.py index 3669613..d8d4b39 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -10,6 +10,8 @@ def query_ids(index: str, q: Optional[str] = None, **kwargs) -> Set[int]: if q is not None: kwargs["queries"] = [q] res = query.query_documents(index, **kwargs) + if res is None: + return set() return {int(h["_id"]) for h in res.data} @@ -36,6 +38,7 @@ def test_range_query(index_docs): def test_fields(index_docs): res = query.query_documents(index_docs, queries=["test"], fields=["cat", "title"]) + assert res is not None assert set(res.data[0].keys()) == {"cat", "title", "_id"} @@ -43,20 +46,14 @@ def test_highlight(index): words = "The error of regarding functional notions is not quite equivalent to" text = f"{words} a test document. {words} other text documents. {words} you!" upload(index, [dict(title="Een test titel", text=text)]) - res = query.query_documents( - index, fields=["title", "text"], queries=["te*"], highlight=True - ) + res = query.query_documents(index, fields=["title", "text"], queries=["te*"], highlight=True) + assert res is not None doc = res.data[0] assert doc["title"] == "Een test titel" - assert ( - doc["text"] - == f"{words} a test document. {words} other text documents. {words} you!" - ) + assert doc["text"] == f"{words} a test document. {words} other text documents. {words} you!" # snippets can also have highlights - doc = query.query_documents( - index, queries=["te*"], fields=["title"], snippets=["text"], highlight=True - ).data[0] + doc = query.query_documents(index, queries=["te*"], fields=["title"], snippets=["text"], highlight=True).data[0] assert doc["title"] == "Een test titel" assert " a test" in doc["text"] assert " ... 
" in doc["text"] @@ -64,7 +61,9 @@ def test_highlight(index): def test_query_multiple_index(index_docs, index): upload(index, [{"text": "also a text", "i": -1}]) - assert len(query.query_documents([index_docs, index]).data) == 5 + docs = query.query_documents([index_docs, index]) + assert docs is not None + assert len(docs.data) == 5 def test_query_filter_mapping(index_docs): From 3e820ade86fc3289a361b3c370f56aeb8e4f3947 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Sat, 13 Jan 2024 16:29:38 +0100 Subject: [PATCH 10/80] cleaaaning --- .github/workflows/linting.yml | 33 +++++++------ .vscode/settings.json | 4 +- amcat4/aggregate.py | 4 +- amcat4/api/index.py | 75 ++++++++++++++++------------- amcat4/api/query.py | 67 +++++++++++++------------- amcat4/api/users.py | 15 +++--- amcat4/index.py | 88 ++++++++++++++++++++++++++++------- amcat4/models.py | 13 +++++- amcat4/query.py | 7 +-- setup.py | 10 +--- 10 files changed, 195 insertions(+), 121 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 7bd5ca3..3985326 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -5,13 +5,12 @@ name: Flake8 on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false @@ -19,17 +18,17 @@ jobs: python-version: ["3.8", "3.9"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install -e .[dev] - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=env - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --max-line-length=127 --statistics --exclude=env + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install -e .[dev] + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=env + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . 
--count --max-line-length=127 --ignore=E203 --statistics --exclude=env diff --git a/.vscode/settings.json b/.vscode/settings.json index 2b4d815..2e83fb4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,9 +3,11 @@ "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true }, + "black-formatter.args": ["--line-length", "127"], "mypy.enabled": true, "mypy.runUsingActiveInterpreter": true, "python.analysis.typeCheckingMode": "basic", - "python.analysis.autoImportCompletions": true + "python.analysis.autoImportCompletions": true, + "flake8.args": ["--max-line-length=127", "--ignore=E203"] } diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index 1ebbec0..f881dc8 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -74,7 +74,7 @@ class Aggregation: Specification of a single aggregation, that is, field and aggregation function """ - def __init__(self, field: str, function: str, name: str = None, ftype: str = None): + def __init__(self, field: str, function: str, name: str | None = None, ftype: str | None = None): self.field = field self.function = function self.name = name or f"{function}_{field}" @@ -132,7 +132,7 @@ def _elastic_aggregate( queries, filters, aggregations: list[Aggregation], - runtime_mappings: dict[str, Mapping] = None, + runtime_mappings: dict[str, Mapping] | None = None, after_key=None, ) -> Iterable[dict]: """ diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 4dc6e00..69aaf39 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -1,24 +1,37 @@ """API Endpoints for document and index management.""" from http import HTTPStatus -from typing import List, Literal, Mapping, Optional +from typing import Annotated, Literal import elasticsearch from elastic_transport import ApiError from fastapi import APIRouter, HTTPException, Response, status, Depends, Body -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel from amcat4 import index from amcat4.api.auth import authenticated_user, authenticated_writer, check_role -from amcat4.api.common import py2dict from amcat4.index import refresh_system_index, remove_role, set_role -from amcat4.models import Field +from amcat4.models import Document, UpdateField app_index = APIRouter(prefix="/index", tags=["index"]) RoleType = Literal["ADMIN", "WRITER", "READER", "METAREADER"] +def _standardize_updatefields(fields: dict[str, str | UpdateField]) -> dict[str, UpdateField]: + standardized_fields: dict[str, UpdateField] = {} + + for name, field in fields.items(): + if isinstance(field, UpdateField): + standardized_fields[name] = field + elif isinstance(field, str): + standardized_fields[name] = UpdateField(type=field) + else: + raise ValueError(f"Cannot parse field: {field}") + + return standardized_fields + + @app_index.get("/") def index_list(current_user: str = Depends(authenticated_user)): """ @@ -40,9 +53,9 @@ class NewIndex(BaseModel): """Form to create a new index.""" id: str - guest_role: Optional[RoleType] = None - name: Optional[str] = None - description: Optional[str] = None + guest_role: RoleType | None = None + name: str | None = None + description: str | None = None @app_index.post("/", status_code=status.HTTP_201_CREATED) @@ -73,9 +86,9 @@ class ChangeIndex(BaseModel): """Form to update an existing index.""" guest_role: Literal["ADMIN", "WRITER", "READER", "METAREADER", "NONE"] | None = "NONE" - name: Optional[str] = None - description: Optional[str] = None - summary_field: Optional[str] = None + name: str | None = None + description: 
str | None = None + summary_field: str | None = None @app_index.put("/{ix}") @@ -128,21 +141,13 @@ def delete_index(ix: str, user: str = Depends(authenticated_user)): index.delete_index(ix) -class Document(BaseModel): - """Form to create (upload) a new document.""" - - title: str - date: str - text: str - url: Optional[str] = None - model_config = ConfigDict(extra="allow") - - @app_index.post("/{ix}/documents", status_code=status.HTTP_201_CREATED) def upload_documents( ix: str, - documents: List[Document] = Body(None, description="The documents to upload"), - columns: Optional[Mapping[str, str]] = Body(None, description="Optional Specification of field (column) types"), + documents: Annotated[list[Document], Body(description="The documents to upload")], + fields: Annotated[ + dict[str, str | UpdateField] | None, Body(description="Optional Specification of field (column) types") + ] = None, user: str = Depends(authenticated_user), ): """ @@ -156,15 +161,18 @@ def upload_documents( Returns a list of ids for the uploaded documents """ check_role(user, index.Role.WRITER, ix) - documents = [py2dict(doc) for doc in documents] - return index.upload_documents(ix, documents, columns) + + if fields is None: + return index.upload_documents(ix, documents) + else: + return index.upload_documents(ix, documents, _standardize_updatefields(fields)) @app_index.get("/{ix}/documents/{docid}") def get_document( ix: str, docid: str, - fields: Optional[str] = None, + fields: str | None = None, user: str = Depends(authenticated_user), ): """ @@ -237,23 +245,24 @@ def get_fields(ix: str, user: str = Depends(authenticated_user)): Returns a json array of {name, type} objects """ check_role(user, index.Role.METAREADER, ix) + if "," in ix: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"/[index]/fields does not support multiple indices", - ) - return index.get_fields(ix) + return index.get_fields(ix.split(",")) + else: + return index.get_fields(ix) @app_index.post("/{ix}/fields") -def set_fields(ix: str, fields=dict[str, Field], user: str = Depends(authenticated_user)): +def set_fields( + ix: str, fields: Annotated[dict[str, str | UpdateField], Body(description="")], user: str = Depends(authenticated_user) +): """ Set the field types used in this index. 
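 
     For illustration, a hedged sketch of the two value shapes the body accepts
     (field names here are hypothetical):
 
         {"party": "keyword", "text": {"metareader": {"access": "snippet"}}}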
- POST body should be a dict of {field: type} or {field: {type: type, meta: meta}} """ check_role(user, index.Role.WRITER, ix) - index.set_fields(ix, fields) + + index.set_fields(ix, _standardize_updatefields(fields)) return "", HTTPStatus.NO_CONTENT diff --git a/amcat4/api/query.py b/amcat4/api/query.py index b946072..b79895c 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -57,16 +57,16 @@ def get_or_validate_allowed_fields( raise ValueError("Fields should be specified if multiple indices are given") index_fields = get_fields(indices[0]) role = get_role(indices[0], user) - allowed_fields = [] + allowed_fields: list[FieldSpec] = [] for field in index_fields.keys(): if role >= Role.READER: - allowed_fields.append(field) + allowed_fields.append(FieldSpec(name=field)) elif role == Role.METAREADER: metareader = index_fields[field].metareader if metareader.access == "read": - allowed_fields.append(field) + allowed_fields.append(FieldSpec(name=field)) if metareader.access == "snippet": - allowed_fields.append({"field": field, "snippet": metareader.max_snippet}) + allowed_fields.append(FieldSpec(name=field, snippet=metareader.max_snippet)) else: raise HTTPException( status_code=401, @@ -74,11 +74,9 @@ def get_or_validate_allowed_fields( ) return allowed_fields - else: - fieldspecs = [field if isinstance(field, FieldSpec) else FieldSpec(name=field) for field in fields] - for index in indices: - check_fields_access(index, user, fieldspecs) - return fieldspecs + for index in indices: + check_fields_access(index, user, fields) + return fields def _standardize_queries(queries: str | list[str] | dict[str, str] | None = None) -> dict[str, str] | None: @@ -115,7 +113,7 @@ def _standardize_filters( return f -def _standardize_fields(fields: list[str | FieldSpec] | None = None) -> list[FieldSpec] | None: +def _standardize_fieldspecs(fields: list[str | FieldSpec] | None = None) -> list[FieldSpec] | None: """Convert fields to list of FieldSpecs.""" if not fields: return None @@ -134,7 +132,7 @@ def _standardize_fields(fields: list[str | FieldSpec] | None = None) -> list[Fie def _standardize_sort(sort: str | list[str] | list[dict[str, SortSpec]] | None = None) -> list[dict[str, SortSpec]] | None: """Convert sort to list of dicts.""" - ## TODO: sort cannot be right. that array around dict is useless + # TODO: sort cannot be right. that array around dict is useless if not sort: return None @@ -234,7 +232,7 @@ def query_documents_post( """ indices = index if isinstance(index, list) else [index] - fieldspecs = get_or_validate_allowed_fields(user, indices, _standardize_fields(fields)) + fieldspecs = get_or_validate_allowed_fields(user, indices, _standardize_fieldspecs(fields)) r = query.query_documents( indices, @@ -275,11 +273,13 @@ def query_aggregate_post( index: str, axes: Optional[List[AxisSpec]] = Body(None, description="Axes to aggregate on (i.e. group by)"), aggregations: Optional[List[AggregationSpec]] = Body(None, description="Aggregate functions to compute"), - queries: Optional[Union[str, List[str], Dict[str, str]]] = Body( - None, - description="Query/Queries to run. Value should be a single query string, a list of query strings, " - "or a dict of queries {'label': 'query'}", - ), + queries: Annotated[ + str | list[str] | dict[str, str] | None, + Body( + description="Query/Queries to run. 
Value should be a single query string, a list of query strings, " + "or a dict of {'label': 'query'}", + ), + ] = None, filters: Annotated[ dict[str, FilterValue | list[FilterValue] | FilterSpec] | None, Body( @@ -287,7 +287,7 @@ def query_aggregate_post( "which can be either a value, a list of values, or a FilterSpec dict", ), ] = None, - _user: str = Depends(authenticated_user), + user: str = Depends(authenticated_user), ): """ Construct an aggregate query. @@ -332,26 +332,29 @@ def query_update_tags( action: Literal["add", "remove"] = Body(None, description="Action (add or remove) on tags"), field: str = Body(None, description="Tag field to update"), tag: str = Body(None, description="Tag to add or remove"), - queries: Optional[Union[str, List[str], Dict[str, str]]] = Body( - None, - description="Query/Queries to run. Value should be a single query string, a list of query strings, " - "or a dict of {'label': 'query'}", - ), - filters: Optional[Dict[str, Union[FilterValue, List[FilterValue], FilterSpec]]] = Body( - None, - description="Field filters, should be a dict of field names to filter specifications," - "which can be either a value, a list of values, or a FilterSpec dict", - ), + queries: Annotated[ + str | list[str] | dict[str, str] | None, + Body( + description="Query/Queries to run. Value should be a single query string, a list of query strings, " + "or a dict of {'label': 'query'}", + ), + ] = None, + filters: Annotated[ + dict[str, FilterValue | list[FilterValue] | FilterSpec] | None, + Body( + description="Field filters, should be a dict of field names to filter specifications," + "which can be either a value, a list of values, or a FilterSpec dict", + ), + ] = None, ids: Optional[Union[str, List[str]]] = Body(None, description="Document IDs of documents to update"), - _user: str = Depends(authenticated_user), + user: str = Depends(authenticated_user), ): """ Add or remove tags by query or by id """ indices = index.split(",") - queries = _standardize_queries(queries) - filters = _standardize_filters(filters) + if isinstance(ids, (str, int)): ids = [ids] - update_tag_query(indices, action, field, tag, queries, filters, ids) + update_tag_query(indices, action, field, tag, _standardize_queries(queries), _standardize_filters(filters), ids) return diff --git a/amcat4/api/users.py b/amcat4/api/users.py index e88b9aa..ff49540 100644 --- a/amcat4/api/users.py +++ b/amcat4/api/users.py @@ -7,8 +7,7 @@ from typing import Literal, Optional from importlib.metadata import version -from fastapi import APIRouter, HTTPException, status, Response -from fastapi.params import Depends +from fastapi import APIRouter, HTTPException, status, Response, Depends from pydantic import BaseModel from pydantic.networks import EmailStr @@ -20,7 +19,7 @@ app_users = APIRouter(tags=["users"]) -ROLE = Literal["ADMIN", "WRITER", "READER", "admin", "writer", "reader"] +ROLE = Literal["ADMIN", "WRITER", "READER", "NONE"] class UserForm(BaseModel): @@ -104,9 +103,13 @@ def modify_user(email: EmailStr, data: ChangeUserForm, _user: str = Depends(auth Modify the given user. Only admin can change users. 
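 
     For illustration (the route path and client are assumed here, not part of
     this module), sending role "NONE" now clears the user's global role:
 
         client.put("/users/someone@example.com", json={"role": "NONE"})
         # -> {"email": "someone@example.com", "role": None}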
""" - role = Role[data.role.upper()] - set_global_role(email, role) - return {"email": email, "role": role.name} + if data.role is None or data.role == "NONE": + set_global_role(email, None) + return {"email": email, "role": None} + else: + role = Role[data.role.upper()] + set_global_role(email, role) + return {"email": email, "role": role.name} @app_users.get("/config") diff --git a/amcat4/index.py b/amcat4/index.py index 08e4f1a..62ce207 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -40,10 +40,11 @@ import elasticsearch.helpers from elasticsearch import NotFoundError +from amcat4.api.common import py2dict from amcat4.config import get_settings from amcat4.elastic import es -from amcat4.models import Field, UpdateField, updateField, FieldMetareaderAccess +from amcat4.models import Document, Field, SnippetParams, UpdateField, updateField, FieldMetareaderAccess # The Field model has a type field as we use it in amcat, but we need to @@ -362,7 +363,7 @@ def modify_index( raise ValueError(f"Summary field {summary_field} should be date, keyword or tag, not {f[summary_field].type}!") doc = {x: v for (x, v) in doc.items() if v} if remove_guest_role: - doc["guest_role"] = None + doc["guest_role"] = Role.NONE if doc: es().update(index=get_settings().system_index, id=index, doc=doc) @@ -404,7 +405,7 @@ def get_role(index: str, email: str) -> Role: guest_role: str | None = doc["_source"].get("guest_role", None) if guest_role and guest_role.lower() != "none": return Role[guest_role] - return None + return Role.NONE def get_guest_role(index: str) -> Optional[Role]: @@ -439,19 +440,71 @@ def get_global_role(email: str, only_es: bool = False) -> Optional[Role]: return get_role(index=GLOBAL_ROLES, email=email) -def get_fields(index: str) -> Dict[str, Field]: +def merge_overlapping_fields(index1: str, index2: str, name: str, field1: Field, field2: Field) -> Field: """ - Retrieve the fields settings for this index + Merge two fields, and take most restrictive metareader settings. Also add in_index list. 
""" - try: - d = es().get( - index=get_settings().system_index, - id=index, - source_includes="fields", + + in_index1 = field1.in_index if field1.in_index is not None else [index1] + in_index2 = field2.in_index if field2.in_index is not None else [index2] + in_index = list(set(in_index1 + in_index2)) + + if field1.type != field2.type: + raise ValueError(f"Field {name} has different types in index {index1} ({field1.type}) and {index2} ({field2.type})") + + if field1.metareader.access == "none" or field2.metareader.access == "none": + return Field(type=field1.type, in_index=in_index, metareader=FieldMetareaderAccess(access="none")) + + if field1.metareader.access == "snippet" and field2.metareader.access == "snippet": + if field1.metareader.max_snippet is None: + return field1 + if field2.metareader.max_snippet is None: + return field2 + + nomatch_chars = min(field1.metareader.max_snippet.nomatch_chars, field2.metareader.max_snippet.nomatch_chars) + max_matches = min(field1.metareader.max_snippet.max_matches, field2.metareader.max_snippet.max_matches) + match_chars = min(field1.metareader.max_snippet.match_chars, field2.metareader.max_snippet.match_chars) + return Field( + type=field1.type, + in_index=in_index, + metareader=FieldMetareaderAccess( + access="snippet", + max_snippet=SnippetParams( + nomatch_chars=nomatch_chars, + max_matches=max_matches, + match_chars=match_chars, + ), + ), ) - except NotFoundError: - raise IndexDoesNotExist(index) - return _fields_from_elastic(d["_source"].get("fields", {})) + + if field1.metareader.access == "snippet" or field2.metareader.access == "snippet": + return Field(type=field1.type, in_index=in_index, metareader=FieldMetareaderAccess(access="snippet")) + return Field(type=field1.type, in_index=in_index, metareader=FieldMetareaderAccess(access="read")) + + +def get_fields(index: str | list[str]) -> dict[str, Field]: + """ + Retrieve the fields settings for this index + """ + fields: dict[str, Field] = {} + indices = [index] if isinstance(index, str) else index + for index in indices: + try: + d = es().get( + index=get_settings().system_index, + id=index, + source_includes="fields", + ) + except NotFoundError: + raise IndexDoesNotExist(index) + + index_fields = _fields_from_elastic(d["_source"].get("fields", {})) + for field, settings in index_fields.items(): + if field not in fields: + fields[field] = settings + else: + fields[field] = merge_overlapping_fields(index, index, field, fields[field], settings) + return fields def list_users(index: str) -> Dict[str, Role]: @@ -514,7 +567,7 @@ def _get_hash(document: dict) -> str: return m.hexdigest() -def upload_documents(index: str, documents, fields: dict[str, Field] | None) -> None: +def upload_documents(index: str, documents: list[Document], fields: dict[str, UpdateField] | None = None) -> None: """ Upload documents to this index @@ -525,7 +578,8 @@ def upload_documents(index: str, documents, fields: dict[str, Field] | None) -> def es_actions(index, documents): field_types = get_fields(index) - for document in documents: + for document_pydantic in documents: + document = py2dict(document_pydantic) for key in document.keys(): if key not in field_types: raise ValueError(f"The type for field {key} is not yet specified") @@ -600,7 +654,7 @@ def get_field_stats(index: str, field: str) -> List[str]: return r["aggregations"]["facets"] -def update_by_query(index: str, script: str, query: dict, params: dict | None = None): +def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None 
= None): script_dict = dict(source=script, lang="painless", params=params or {}) es().update_by_query(index=index, script=script_dict, **query) @@ -623,7 +677,7 @@ def update_by_query(index: str, script: str, query: dict, params: dict | None = ) -def update_tag_by_query(index: str, action: Literal["add", "remove"], query: dict, field: str, tag: str): +def update_tag_by_query(index: str | list[str], action: Literal["add", "remove"], query: dict, field: str, tag: str): script = TAG_SCRIPTS[action] params = dict(field=field, tag=tag) update_by_query(index, script, query, params) diff --git a/amcat4/models.py b/amcat4/models.py index bfddd36..0c14564 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from typing import Literal, NewType, Union @@ -26,6 +26,7 @@ class Field(BaseModel): type: str metareader: FieldMetareaderAccess + in_index: list[str] | None = None class UpdateField(BaseModel): @@ -66,3 +67,13 @@ class SortSpec(BaseModel): """Form for sort specification.""" order: Literal["asc", "desc"] = "asc" + + +class Document(BaseModel): + """Form to create (upload) a new document.""" + + title: str + date: str + text: str + url: str | None = None + model_config = ConfigDict(extra="allow") diff --git a/amcat4/query.py b/amcat4/query.py index 9fa322c..67b7cea 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -262,10 +262,11 @@ def update_tag_query( action: Literal["add", "remove"], field: str, tag: str, - queries: dict[str, str] | list[str] | None = None, - filters: dict[str, dict] | None = None, + queries: dict[str, str] | None = None, + filters: dict[str, FilterSpec] | None = None, ids: list[str] | None = None, ): """Add or remove tags using a query""" - body = build_body(queries and queries.values(), filters, ids=ids) + body = build_body(queries, filters, ids=ids) + update_tag_by_query(index, action, body, field, tag) diff --git a/setup.py b/setup.py index 4c92880..c9b1d0b 100644 --- a/setup.py +++ b/setup.py @@ -31,14 +31,6 @@ "requests", "class_doc", ], - extras_require={ - "dev": [ - "pytest", - "mypy", - "flake8", - "responses", - "pre-commit", - ] - }, + extras_require={"dev": ["pytest", "mypy", "flake8", "responses", "pre-commit", "types-requests"]}, entry_points={"console_scripts": ["amcat4 = amcat4.__main__:main"]}, ) From 769a372052b87fd28066e0bd8d2ac4ba414fa94f Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Mon, 15 Jan 2024 09:35:11 +0100 Subject: [PATCH 11/80] in progress --- amcat4/__main__.py | 23 ++++++---- amcat4/api/index.py | 8 +--- amcat4/config.py | 108 ++++++++++++++++++++++++-------------------- amcat4/elastic.py | 10 +++- amcat4/index.py | 4 +- amcat4/models.py | 12 +---- 6 files changed, 86 insertions(+), 79 deletions(-) diff --git a/amcat4/__main__.py b/amcat4/__main__.py index bfd3b56..bbd283e 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -2,7 +2,6 @@ AmCAT4 REST API """ import argparse -import collections import csv import io import json @@ -10,9 +9,11 @@ import os import secrets import sys +from typing import Any import urllib.request from enum import Enum -import elasticsearch +from collections import defaultdict +import elasticsearch.helpers import uvicorn from pydantic.fields import FieldInfo @@ -21,6 +22,7 @@ from amcat4.config import get_settings, AuthOptions, validate_settings from amcat4.elastic import connect_elastic, get_system_version, ping from amcat4.index import GLOBAL_ROLES, create_index, set_global_role, Role, list_global_users, 
upload_documents +from amcat4.models import UpdateField SOTU_INDEX = "state_of_the_union" @@ -45,7 +47,7 @@ def upload_test_data() -> str: ) for row in csvfile ] - columns = {"president": "keyword", "party": "keyword", "year": "double"} + columns = dict(president=UpdateField(type="keyword"), party=UpdateField(type="keyword"), year=UpdateField(type="double")) upload_documents(SOTU_INDEX, docs, columns) return SOTU_INDEX @@ -77,7 +79,7 @@ def val(val_or_list): return val_or_list -def migrate_index(_args): +def migrate_index(_args) -> None: settings = get_settings() elastic = connect_elastic() if not elastic.ping(): @@ -94,7 +96,8 @@ def migrate_index(_args): else: logging.info("Migrating to version 1") fields = ["index", "email", "role"] - indices = collections.defaultdict(dict) + indices: defaultdict[str, dict[str, str]] = defaultdict(dict) + for entry in elasticsearch.helpers.scan(elastic, index=settings.system_index, fields=fields, _source=False): index, email, role = [val(entry["fields"][field]) for field in fields] indices[index][email] = role @@ -106,10 +109,9 @@ def migrate_index(_args): guest_role = roles_dict.pop("_guest", None) roles_dict.pop("admin", None) roles = [{"email": email, "role": role} for (email, role) in roles_dict.items()] - doc = dict(name=index, guest_role=guest_role, roles=roles) + doc: dict[str, Any] = dict(name=index, guest_role=guest_role, roles=roles) if index == GLOBAL_ROLES: doc["version"] = 1 - print(doc) elastic.index(index=settings.system_index, id=index, document=doc) except Exception: try: @@ -166,9 +168,12 @@ def list_users(_args): admin_password = get_settings().admin_password if admin_password: print("ADMIN : admin (password set via environment AMCAT4_ADMIN_PASSWORD)") - users = sorted(list_global_users(), key=lambda ur: (ur[1], ur[0])) + users = list_global_users() + + # sorted changes the output type of list_global_users? + # users = sorted(list_global_users(), key=lambda ur: (ur[1], ur[0])) if users: - for user, role in users: + for user, role in users.items(): print(f"{role.name:10}: {user}") if not (users or admin_password): print("(No users defined yet, set AMCAT4_ADMIN_PASSWORD in environment use add-admin to add users by email)") diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 69aaf39..d18c46e 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -11,7 +11,7 @@ from amcat4.api.auth import authenticated_user, authenticated_writer, check_role from amcat4.index import refresh_system_index, remove_role, set_role -from amcat4.models import Document, UpdateField +from amcat4.models import UpdateField app_index = APIRouter(prefix="/index", tags=["index"]) @@ -144,7 +144,7 @@ def delete_index(ix: str, user: str = Depends(authenticated_user)): @app_index.post("/{ix}/documents", status_code=status.HTTP_201_CREATED) def upload_documents( ix: str, - documents: Annotated[list[Document], Body(description="The documents to upload")], + documents: Annotated[list[dict[str, str]], Body(description="The documents to upload")], fields: Annotated[ dict[str, str | UpdateField] | None, Body(description="Optional Specification of field (column) types") ] = None, @@ -154,10 +154,6 @@ def upload_documents( Upload documents to this server. 
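 
     For illustration, a hedged sketch of a payload using the new `fields` body
     parameter (which supersedes the `columns` key described next; all names and
     values here are hypothetical):
 
         {
             "documents": [{"title": "Test", "date": "2024-01-15", "text": "A text"}],
             "fields": {"party": "keyword", "year": {"type": "double"}}
         }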
JSON payload should contain a `documents` key, and may contain a `columns` key:
-    {
-        "documents": [{"title": .., "date": .., "text": .., ...}, ...],
-        "columns": {<field>: <type>, ...}
-    }
 
     Returns a list of ids for the uploaded documents
     """
     check_role(user, index.Role.WRITER, ix)
+
+    if fields is None:
+        return index.upload_documents(ix, documents)
+    else:
+        return index.upload_documents(ix, documents, _standardize_updatefields(fields))
 
 
 @app_index.get("/{ix}/documents/{docid}")
 def get_document(
     ix: str,
     docid: str,
-    fields: Optional[str] = None,
+    fields: str | None = None,
     user: str = Depends(authenticated_user),
 ):
     """
diff --git a/amcat4/config.py b/amcat4/config.py
index 1c6506c..0f70d60 100644
--- a/amcat4/config.py
+++ b/amcat4/config.py
@@ -9,7 +9,7 @@
 import functools
 from enum import Enum
 from pathlib import Path
-from typing import Optional
+from typing import Annotated
 from class_doc import extract_docs_from_cls_obj
 from dotenv import load_dotenv
 from pydantic import model_validator, Field
@@ -43,64 +43,76 @@ def validate(cls, value: str):
 
 
 class Settings(BaseSettings):
-    env_file: Path = Field(
-        ".env",
-        description="Location of a .env file (if used) relative to working directory",
-    )
-    host: str = Field(
-        "http://localhost:5000",
-        description="Host this instance is served at (needed for checking tokens)",
-    )
-
-    elastic_password: Optional[str] = Field(
-        None,
-        description=(
-            "Elasticsearch password. "
-            "This the password for the 'elastic' user when Elastic xpack security is enabled"
+    env_file: Annotated[
+        Path,
+        Field(
+            description="Location of a .env file (if used) relative to working directory",
         ),
-    )
+    ] = Path(".env")
+    host: Annotated[
+        str,
+        Field(
+            description="Host this instance is served at (needed for checking tokens)",
+        ),
+    ] = "http://localhost:5000"
 
-    elastic_host: Optional[str] = Field(
-        None,
-        description=(
-            "Elasticsearch host. "
-            "Default: https://localhost:9200 if elastic_password is set, http://localhost:9200 otherwise"
+    elastic_password: Annotated[
+        str | None,
+        Field(
+            description=(
+                "Elasticsearch password. " "This is the password for the 'elastic' user when Elastic xpack security is enabled"
+            )
+        ),
+    ] = None
+
+    elastic_host: Annotated[
+        str | None,
+        Field(
+            description=(
+                "Elasticsearch host. "
+                "Default: https://localhost:9200 if elastic_password is set, http://localhost:9200 otherwise"
+            )
         ),
-    )
+    ] = None
+
+    elastic_verify_ssl: Annotated[
+        bool | None,
+        Field(
+            description=(
+                "Elasticsearch verify SSL (only used if elastic_password is set). " "Default: True unless host is localhost"
+            ),
+        ),
+    ] = None
 
-    elastic_verify_ssl: Optional[bool] = Field(
-        None,
-        description=(
-            "Elasticsearch verify SSL (only used if elastic_password is set). "
-            "Default: True unless host is localhost)"
+    system_index: Annotated[
+        str,
+        Field(
+            description="Elasticsearch index to store authorization information in",
        ),
-    )
+    ] = "amcat4_system"
 
-    system_index: str = Field(
-        "amcat4_system",
-        description="Elasticsearch index to store authorization information in",
-    )
+    auth: Annotated[AuthOptions, Field(description="Do we require authorization?")] = AuthOptions.no_auth
 
-    auth: AuthOptions = Field(
-        AuthOptions.no_auth, description="Do we require authorization?"
- ) + middlecat_url: Annotated[ + str, + Field( + description="Middlecat server to trust as ID provider", + ), + ] = "https://middlecat.up.railway.app" - middlecat_url: str = Field( - "https://middlecat.up.railway.app", - description="Middlecat server to trust as ID provider", - ) + admin_email: Annotated[ + str | None, + Field( + description="Email address for a hardcoded admin email (useful for setup and recovery)", + ), + ] = None - admin_email: Optional[str] = Field( - None, - description="Email address for a hardcoded admin email (useful for setup and recovery)", - ) + admin_password: Annotated[str | None, Field()] = None @model_validator(mode="after") def set_ssl(self) -> "Settings": if not self.elastic_host: - self.elastic_host = ( - "https" if self.elastic_password else "http" - ) + "://localhost:9200" + self.elastic_host = ("https" if self.elastic_password else "http") + "://localhost:9200" if not self.elastic_verify_ssl: self.elastic_verify_ssl = self.elastic_host not in { "http://localhost:9200", @@ -123,9 +135,7 @@ def get_settings() -> Settings: def validate_settings(): if get_settings().auth != "no_auth": - if get_settings().host.startswith( - "http://" - ) and not get_settings().host.startswith("http://localhost"): + if get_settings().host.startswith("http://") and not get_settings().host.startswith("http://localhost"): return ( "You have set the host at an http address and enabled authentication." "Authentication through middlecat will not work in your browser" diff --git a/amcat4/elastic.py b/amcat4/elastic.py index bb16ffa..ef0a508 100644 --- a/amcat4/elastic.py +++ b/amcat4/elastic.py @@ -45,10 +45,16 @@ def connect_elastic() -> Elasticsearch: """ settings = get_settings() if settings.elastic_password: + host = settings.elastic_host + if settings.elastic_verify_ssl is None: + verify_certs = "localhost" in (host or "") + else: + verify_certs = settings.elastic_verify_ssl + return Elasticsearch( - settings.elastic_host or None, + host, basic_auth=("elastic", settings.elastic_password), - verify_certs=settings.elastic_verify_ssl, + verify_certs=verify_certs, ) else: return Elasticsearch(settings.elastic_host or None) diff --git a/amcat4/index.py b/amcat4/index.py index 62ce207..75940a3 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -44,7 +44,7 @@ from amcat4.config import get_settings from amcat4.elastic import es -from amcat4.models import Document, Field, SnippetParams, UpdateField, updateField, FieldMetareaderAccess +from amcat4.models import Field, SnippetParams, UpdateField, updateField, FieldMetareaderAccess # The Field model has a type field as we use it in amcat, but we need to @@ -567,7 +567,7 @@ def _get_hash(document: dict) -> str: return m.hexdigest() -def upload_documents(index: str, documents: list[Document], fields: dict[str, UpdateField] | None = None) -> None: +def upload_documents(index: str, documents: list[dict[str, str]], fields: dict[str, UpdateField] | None = None) -> None: """ Upload documents to this index diff --git a/amcat4/models.py b/amcat4/models.py index 0c14564..cb81477 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Extra from typing import Literal, NewType, Union @@ -67,13 +67,3 @@ class SortSpec(BaseModel): """Form for sort specification.""" order: Literal["asc", "desc"] = "asc" - - -class Document(BaseModel): - """Form to create (upload) a new document.""" - - title: str - date: str - text: str - url: str | None = None - 
model_config = ConfigDict(extra="allow") From 6773d4ffd2aaf591815080e0f49fe7a24bcd3ad2 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Mon, 15 Jan 2024 16:41:35 +0100 Subject: [PATCH 12/80] some stuff --- amcat4/api/__init__.py | 24 ++++++++++++------ amcat4/api/auth.py | 3 ++- amcat4/api/query.py | 5 ++-- amcat4/index.py | 55 +++++++++++++++++++++++++++--------------- amcat4/models.py | 14 ++++++++--- amcat4/query.py | 4 +-- 6 files changed, 69 insertions(+), 36 deletions(-) diff --git a/amcat4/api/__init__.py b/amcat4/api/__init__.py index 57b0427..35e8e43 100644 --- a/amcat4/api/__init__.py +++ b/amcat4/api/__init__.py @@ -1,6 +1,7 @@ """AmCAT4 API.""" -from fastapi import FastAPI +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from amcat4.api.index import app_index @@ -15,14 +16,15 @@ dict(name="users", description="Endpoints for user management"), dict(name="index", description="Endpoints to create, list, and delete indices; and to add or modify documents"), dict(name="query", description="Endpoints to list or query documents or run aggregate queries"), - dict(name='middlecat', description="MiddleCat authentication"), + dict(name="middlecat", description="MiddleCat authentication"), dict(name="annotator users", description="Annotator module endpoints for user management"), - dict(name="annotator codingjob", - description="Annotator module endpoints for creating and managing annotator codingjobs, " - "and the core process of getting units and posting annotations"), + dict( + name="annotator codingjob", + description="Annotator module endpoints for creating and managing annotator codingjobs, " + "and the core process of getting units and posting annotations", + ), dict(name="annotator guest", description="Annotator module endpoints for unregistered guests"), - ] - + ], ) app.include_router(app_info) app.include_router(app_users) @@ -35,3 +37,11 @@ allow_methods=["*"], allow_headers=["*"], ) + + +@app.exception_handler(ValueError) +async def value_error_exception_handler(request: Request, exc: ValueError): + return JSONResponse( + status_code=400, + content={"message": str(exc)}, + ) diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py index 095d444..14d6ea7 100644 --- a/amcat4/api/auth.py +++ b/amcat4/api/auth.py @@ -1,4 +1,4 @@ -"""Helper methods for authentication.""" +"""Helper methods for authentication and authorization.""" import functools import logging from datetime import datetime @@ -14,6 +14,7 @@ from amcat4.config import get_settings, AuthOptions from amcat4.index import Role, get_role, get_global_role, get_fields + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index b79895c..79ffa0c 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -154,7 +154,7 @@ def _standardize_sort(sort: str | list[str] | list[dict[str, SortSpec]] | None = @app_query.post("/{index}/query", response_model=QueryResult) def query_documents_post( - index: str | list[str], + index: str, queries: Annotated[ str | list[str] | dict[str, str] | None, Body( @@ -194,7 +194,6 @@ def query_documents_post( sort: Annotated[ str | list[str] | list[dict[str, SortSpec]] | None, Body( - None, description="Sort by field name(s) or dict (see " "https://www.elastic.co/guide/en/elasticsearch/reference/current/sort-search-results.html for dict format)", openapi_examples={ @@ -231,7 +230,7 @@ def query_documents_post( Returns a JSON object 
{data: [...], meta: {total_count, per_page, page_count, page|scroll_id}} """ - indices = index if isinstance(index, list) else [index] + indices = index.split(",") fieldspecs = get_or_validate_allowed_fields(user, indices, _standardize_fieldspecs(fields)) r = query.query_documents( diff --git a/amcat4/index.py b/amcat4/index.py index 75940a3..6702557 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -33,18 +33,18 @@ """ import collections from enum import IntEnum -from typing import Dict, Iterable, List, Optional, Literal +from typing import Iterable, Iterator, Optional, Literal import hashlib import json import elasticsearch.helpers from elasticsearch import NotFoundError -from amcat4.api.common import py2dict +# from amcat4.api.common import py2dict from amcat4.config import get_settings from amcat4.elastic import es -from amcat4.models import Field, SnippetParams, UpdateField, updateField, FieldMetareaderAccess +from amcat4.models import Field, FieldClientDisplay, SnippetParams, UpdateField, updateField, FieldMetareaderAccess # The Field model has a type field as we use it in amcat, but we need to @@ -78,10 +78,10 @@ } DEFAULT_INDEX_FIELDS = { - "text": Field(type="text", metareader=DEFAULT_METAREADER["text"]), - "title": Field(type="text", metareader=DEFAULT_METAREADER["text"]), - "date": Field(type="date", metareader=DEFAULT_METAREADER["date"]), - "url": Field(type="url", metareader=DEFAULT_METAREADER["url"]), + "text": Field(type="text", metareader=DEFAULT_METAREADER["text"], client_display=FieldClientDisplay(in_list=True)), + "title": Field(type="text", metareader=DEFAULT_METAREADER["text"], client_display=FieldClientDisplay(in_list=True)), + "date": Field(type="date", metareader=DEFAULT_METAREADER["date"], client_display=FieldClientDisplay(in_list=True)), + "url": Field(type="url", metareader=DEFAULT_METAREADER["url"], client_display=FieldClientDisplay(in_list=True)), } @@ -237,11 +237,11 @@ def deregister_index(index: str, ignore_missing=False) -> None: refresh_index(system_index) -def _roles_from_elastic(roles: List[Dict]) -> Dict[str, Role]: +def _roles_from_elastic(roles: list[dict]) -> dict[str, Role]: return {role["email"]: Role[role["role"].upper()] for role in roles} -def _roles_to_elastic(roles: dict) -> List[Dict]: +def _roles_to_elastic(roles: dict) -> list[dict]: return [{"email": email, "role": role.name} for (email, role) in roles.items()] @@ -284,14 +284,14 @@ def set_guest_role(index: str, guest_role: Optional[Role]): modify_index(index, guest_role=guest_role, remove_guest_role=(guest_role is None)) -def _fields_to_elastic(fields: Dict[str, Field]) -> List[Dict]: +def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: return [{"field": field, "settings": settings} for field, settings in fields.items()] def _fields_from_elastic( - fields: List[Dict], -) -> Dict[str, Field]: - return {fs["field"]: fs["settings"] for fs in fields} + fields: list[dict], +) -> dict[str, Field]: + return {fs["field"]: Field.model_validate(fs["settings"]) for fs in fields} def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField]): @@ -482,9 +482,17 @@ def merge_overlapping_fields(index1: str, index2: str, name: str, field1: Field, return Field(type=field1.type, in_index=in_index, metareader=FieldMetareaderAccess(access="read")) +def _get_index_fields(index: str) -> Iterator[tuple[str, str]]: + r = es().indices.get_mapping(index=index) + for k, v in r[index]["mappings"]["properties"].items(): + yield k, v.get("type", "object") + + def get_fields(index: 
str | list[str]) -> dict[str, Field]: """ - Retrieve the fields settings for this index + Retrieve the fields settings for this index. Look for both the field settings in the system index, + and the field mappings in the index itself. If a field is not defined in the system index, return the + default settings for that field type. """ fields: dict[str, Field] = {} indices = [index] if isinstance(index, str) else index @@ -499,15 +507,22 @@ def get_fields(index: str | list[str]) -> dict[str, Field]: raise IndexDoesNotExist(index) index_fields = _fields_from_elastic(d["_source"].get("fields", {})) - for field, settings in index_fields.items(): + + for field, fieldtype in _get_index_fields(index): + if field not in index_fields: + settings = Field(type=fieldtype, metareader=DEFAULT_METAREADER[fieldtype]) + else: + settings = index_fields[field] + if field not in fields: + settings.in_index = [index] fields[field] = settings else: fields[field] = merge_overlapping_fields(index, index, field, fields[field], settings) return fields -def list_users(index: str) -> Dict[str, Role]: +def list_users(index: str) -> dict[str, Role]: """ " List all users and their roles on the given index :param index: The index to list roles for. @@ -517,7 +532,7 @@ def list_users(index: str) -> Dict[str, Role]: return _roles_from_elastic(r["_source"].get("roles", [])) -def list_global_users() -> Dict[str, Role]: +def list_global_users() -> dict[str, Role]: """ " List all global users and their roles :returns: an iterable of (user, Role) pairs @@ -579,7 +594,7 @@ def upload_documents(index: str, documents: list[dict[str, str]], fields: dict[s def es_actions(index, documents): field_types = get_fields(index) for document_pydantic in documents: - document = py2dict(document_pydantic) + document = document_pydantic.model_dump() for key in document.keys(): if key not in field_types: raise ValueError(f"The type for field {key} is not yet specified") @@ -628,7 +643,7 @@ def delete_document(index: str, doc_id: str): es().delete(index=index, id=doc_id) -def get_field_values(index: str, field: str, size: int) -> List[str]: +def get_field_values(index: str, field: str, size: int) -> list[str]: """ Get the values for a given field (e.g. 
to populate list of filter values on keyword field)
     Results are sorted descending by document frequency
 
@@ -643,7 +658,7 @@ def get_field_values(index: str, field: str, size: int) -> list[str]:
     return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]]
 
 
-def get_field_stats(index: str, field: str) -> List[str]:
+def get_field_stats(index: str, field: str) -> list[str]:
     """
     :param index: The index
     :param field: The field name
diff --git a/amcat4/models.py b/amcat4/models.py
index cb81477..d84ed50 100644
--- a/amcat4/models.py
+++ b/amcat4/models.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, ConfigDict, Extra
+from pydantic import BaseModel
 from typing import Literal, NewType, Union
 
@@ -14,10 +14,17 @@ class SnippetParams(BaseModel):
     match_chars: int = 0
 
 
+class FieldClientDisplay(BaseModel):
+    """Client display settings for a specific field."""
+
+    in_list: bool = False
+    in_document: bool = True
+
+
 class FieldMetareaderAccess(BaseModel):
     """Metareader access for a specific field."""
 
-    access: Literal["none", "read", "snippet"]
+    access: Literal["none", "read", "snippet"] = "none"
     max_snippet: SnippetParams | None = None
 
 
@@ -25,7 +32,8 @@ class Field(BaseModel):
     """Settings for a field."""
 
     type: str
-    metareader: FieldMetareaderAccess
+    metareader: FieldMetareaderAccess = FieldMetareaderAccess()
+    client_display: FieldClientDisplay = FieldClientDisplay()
     in_index: list[str] | None = None
 
 
diff --git a/amcat4/query.py b/amcat4/query.py
index 67b7cea..a1d88ad 100644
--- a/amcat4/query.py
+++ b/amcat4/query.py
@@ -180,8 +180,8 @@ def query_documents(
     body = build_body(queries, filters, h)
 
     if fields:
-        fields = fields if isinstance(fields, list) else list(fields)
-        kwargs["_source"] = fields
+        fieldnames = [field.name for field in fields]
+        kwargs["_source"] = fieldnames
     if not scroll:
         kwargs["from_"] = page * per_page
     result = es().search(index=index, size=per_page, **body, **kwargs)

From 17f6022b7bd2e2e3cf64306718996981c0cc875b Mon Sep 17 00:00:00 2001
From: Kasper Welbers
Date: Tue, 16 Jan 2024 12:40:59 +0100
Subject: [PATCH 13/80] new field settings work

---
 amcat4/api/auth.py  | 13 +++++++------
 amcat4/api/index.py |  4 ++--
 amcat4/index.py     | 34 ++++++++++++++++++----------------
 amcat4/models.py    | 13 +++++++------
 amcat4/query.py     | 13 +++++++------
 5 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py
index 14d6ea7..be941e0 100644
--- a/amcat4/api/auth.py
+++ b/amcat4/api/auth.py
@@ -146,11 +146,12 @@ def check_fields_access(index: str, user: str, fields: list[FieldSpec]) -> None:
             if metareader.max_snippet is None:
                 max_params_msg = ""
             else:
-                max_params_msg = "Can only read snippet with max parameters:"
-                f"\n- nomatch_chars = {metareader.max_snippet.nomatch_chars}"
-                f"\n- max_matches = {metareader.max_snippet.max_matches}"
-                f"\n- match_chars = {metareader.max_snippet.match_chars}"
-
+                max_params_msg = (
+                    "Can only read snippet with max parameters:"
+                    f" nomatch_chars={metareader.max_snippet.nomatch_chars}"
+                    f", max_matches={metareader.max_snippet.max_matches}"
+                    f", match_chars={metareader.max_snippet.match_chars}"
+                )
             if field.snippet is None:
                 # if snippet is not specified, the whole field is requested
                 raise HTTPException(
@@ -169,7 +170,7 @@ def check_fields_access(index: str, user: str, fields: list[FieldSpec]) -> None:
         else:
             raise HTTPException(
                 status_code=401,
-                detail=f"METAREADER cannot read {field.name} on index {index}",
+                detail=f"METAREADER cannot read {field.name} on index {index}",
             )
 
 
diff --git
a/amcat4/api/index.py b/amcat4/api/index.py index d18c46e..b5a1068 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -127,8 +127,8 @@ def view_index(ix: str, user: str = Depends(authenticated_user)): try: role = check_role(user, index.Role.METAREADER, ix, required_global_role=index.Role.WRITER) d = index.get_index(ix)._asdict() - d["user_role"] = role and role.name - d["guest_role"] = d["guest_role"].name if d.get("guest_role") else None + d["user_role"] = role.name + d["guest_role"] = d.get("guest_role", index.Role.NONE.name) return d except index.IndexDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") diff --git a/amcat4/index.py b/amcat4/index.py index 6702557..50af74a 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -285,7 +285,7 @@ def set_guest_role(index: str, guest_role: Optional[Role]): def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: - return [{"field": field, "settings": settings} for field, settings in fields.items()] + return [{"field": field, "settings": settings.model_dump()} for field, settings in fields.items()] def _fields_from_elastic( @@ -301,25 +301,34 @@ def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField] Note that we're storing fields in two places. We keep all field settings in the system index. But the index that contains the documents also needs to know what fields there are and what their (elastic) types are. So whenever fields are added or the type is updated, we - also udpate the index mapping. + also update the index mapping. """ system_index = get_settings().system_index try: - d = es().get(index=system_index, id=index, source_includes="fields") + es().get(index=system_index, id=index, source_includes="fields") except NotFoundError: raise ValueError(f"Index {index} is not registered") type_mappings = {} - fields = _fields_from_elastic(d["_source"].get("fields", {})) + fields = get_fields(index) + + # Field type specific validation + for field, settings in new_fields.items(): + type = fields.get(field, settings).type + if type != "text": + if settings.metareader and settings.metareader.access == "snippet": + raise ValueError(f"Field {field} is not of type text, cannot set metareader access to snippet") for field, new_settings in new_fields.items(): current = fields.get(field) + if current is None: # Create field if new_settings.type is None: raise ValueError(f"Field {field} does not yet exist, and to create a new field you need to specify a type") type_mappings[field] = ES_MAPPINGS[new_settings.type] + new_settings.metareader = new_settings.metareader or DEFAULT_METAREADER[new_settings.type] fields[field] = Field(**new_settings.model_dump()) else: # Update field @@ -329,7 +338,6 @@ def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField] raise ValueError( f"Field {field} already exists with type {current.type}, cannot change to {new_settings.type}" ) - # set new field settings (amcat type, metareader, etc.) 
fields[field] = updateField(current, new_settings) @@ -337,7 +345,7 @@ def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField] es().update( index=system_index, id=index, - doc=dict(roles=_fields_to_elastic(fields)), + doc=dict(fields=_fields_to_elastic(fields)), ) @@ -442,18 +450,14 @@ def get_global_role(email: str, only_es: bool = False) -> Optional[Role]: def merge_overlapping_fields(index1: str, index2: str, name: str, field1: Field, field2: Field) -> Field: """ - Merge two fields, and take most restrictive metareader settings. Also add in_index list. + Merge two fields, and take most restrictive metareader settings. """ - in_index1 = field1.in_index if field1.in_index is not None else [index1] - in_index2 = field2.in_index if field2.in_index is not None else [index2] - in_index = list(set(in_index1 + in_index2)) - if field1.type != field2.type: raise ValueError(f"Field {name} has different types in index {index1} ({field1.type}) and {index2} ({field2.type})") if field1.metareader.access == "none" or field2.metareader.access == "none": - return Field(type=field1.type, in_index=in_index, metareader=FieldMetareaderAccess(access="none")) + return Field(type=field1.type, metareader=FieldMetareaderAccess(access="none")) if field1.metareader.access == "snippet" and field2.metareader.access == "snippet": if field1.metareader.max_snippet is None: @@ -466,7 +470,6 @@ def merge_overlapping_fields(index1: str, index2: str, name: str, field1: Field, match_chars = min(field1.metareader.max_snippet.match_chars, field2.metareader.max_snippet.match_chars) return Field( type=field1.type, - in_index=in_index, metareader=FieldMetareaderAccess( access="snippet", max_snippet=SnippetParams( @@ -478,8 +481,8 @@ def merge_overlapping_fields(index1: str, index2: str, name: str, field1: Field, ) if field1.metareader.access == "snippet" or field2.metareader.access == "snippet": - return Field(type=field1.type, in_index=in_index, metareader=FieldMetareaderAccess(access="snippet")) - return Field(type=field1.type, in_index=in_index, metareader=FieldMetareaderAccess(access="read")) + return Field(type=field1.type, metareader=FieldMetareaderAccess(access="snippet")) + return Field(type=field1.type, metareader=FieldMetareaderAccess(access="read")) def _get_index_fields(index: str) -> Iterator[tuple[str, str]]: @@ -515,7 +518,6 @@ def get_fields(index: str | list[str]) -> dict[str, Field]: settings = index_fields[field] if field not in fields: - settings.in_index = [index] fields[field] = settings else: fields[field] = merge_overlapping_fields(index, index, field, fields[field], settings) diff --git a/amcat4/models.py b/amcat4/models.py index d84ed50..798cd05 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -1,5 +1,6 @@ +import pydantic from pydantic import BaseModel -from typing import Literal, NewType, Union +from typing import Annotated, Literal class SnippetParams(BaseModel): @@ -9,9 +10,9 @@ class SnippetParams(BaseModel): the first [nomatch_chars] of the field. 
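 
     As an illustrative reading of these parameters: SnippetParams(nomatch_chars=150,
     max_matches=3, match_chars=50) yields up to three 50-character fragments around
     query matches, falling back to the first 150 characters when nothing matches.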
""" - nomatch_chars: int = 0 - max_matches: int = 0 - match_chars: int = 0 + nomatch_chars: Annotated[int, pydantic.Field(ge=1)] = 1 + max_matches: Annotated[int, pydantic.Field(ge=0)] = 0 + match_chars: Annotated[int, pydantic.Field(ge=1)] = 1 class FieldClientDisplay(BaseModel): @@ -34,7 +35,6 @@ class Field(BaseModel): type: str metareader: FieldMetareaderAccess = FieldMetareaderAccess() client_display: FieldClientDisplay = FieldClientDisplay() - in_index: list[str] | None = None class UpdateField(BaseModel): @@ -42,10 +42,11 @@ class UpdateField(BaseModel): type: str | None = None metareader: FieldMetareaderAccess | None = None + client_display: FieldClientDisplay | None = None def updateField(field: Field, update: UpdateField | Field): - for key in field.model_fields_set: + for key in update.model_fields_set: setattr(field, key, getattr(update, key)) return field diff --git a/amcat4/query.py b/amcat4/query.py index a1d88ad..02a441a 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -215,8 +215,8 @@ def query_highlight(fields: list[FieldSpec], highlight_queries: bool = False) -> """ highlight: dict[str, Any] = { - # "pre_tags": [""] if highlight is True else [""], - # "post_tags": [""] if highlight is True else [""], + "pre_tags": [""] if highlight_queries is True else [""], + "post_tags": [""] if highlight_queries is True else [""], "require_field_match": True, "fields": {}, } @@ -233,11 +233,12 @@ def query_highlight(fields: list[FieldSpec], highlight_queries: bool = False) -> highlight["fields"][field.name] = { "no_match_size": field.snippet.nomatch_chars, "number_of_fragments": field.snippet.max_matches, - "fragment_size": field.snippet.match_chars, + "fragment_size": field.snippet.match_chars or 1, # 0 would return the whole field } - if highlight_queries is False or field.snippet.max_matches == 0: - # This overwrites the actual query, so that the highlights are not returned. 
- # Also, if max_matches is zero, we drop the query for highlighting so that + print(field.snippet) + if field.snippet.max_matches == 0: + print("heey") + # If max_matches is zero, we drop the query for highlighting so that # the nomatch_chars are returned highlight["fields"][field.name]["highlight_query"] = {"match_all": {}} From bb0d0a970fbda8af15e7bd46354497c8b203fb30 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Tue, 16 Jan 2024 15:01:07 +0100 Subject: [PATCH 14/80] test rewrite hell --- amcat4/__main__.py | 6 +-- amcat4/aggregate.py | 2 +- amcat4/index.py | 99 +++++++++++++---------------------------- amcat4/query.py | 2 - tests/conftest.py | 9 ++-- tests/test_aggregate.py | 15 ++++--- tests/test_api_index.py | 66 ++++++++------------------- tests/test_api_user.py | 26 ++++++----- tests/test_index.py | 16 ++++--- tests/tools.py | 9 ++-- 10 files changed, 91 insertions(+), 159 deletions(-) diff --git a/amcat4/__main__.py b/amcat4/__main__.py index bbd283e..61f92a3 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -22,7 +22,7 @@ from amcat4.config import get_settings, AuthOptions, validate_settings from amcat4.elastic import connect_elastic, get_system_version, ping from amcat4.index import GLOBAL_ROLES, create_index, set_global_role, Role, list_global_users, upload_documents -from amcat4.models import UpdateField +from amcat4.models import Field SOTU_INDEX = "state_of_the_union" @@ -47,7 +47,7 @@ def upload_test_data() -> str: ) for row in csvfile ] - columns = dict(president=UpdateField(type="keyword"), party=UpdateField(type="keyword"), year=UpdateField(type="double")) + columns = dict(president=Field(type="keyword"), party=Field(type="keyword"), year=Field(type="double")) upload_documents(SOTU_INDEX, docs, columns) return SOTU_INDEX @@ -107,7 +107,7 @@ def migrate_index(_args) -> None: elastic.indices.delete(index=settings.system_index) for index, roles_dict in indices.items(): guest_role = roles_dict.pop("_guest", None) - roles_dict.pop("admin", None) + roles_dict.pop("ADMIN", None) roles = [{"email": email, "role": role} for (email, role) in roles_dict.items()] doc: dict[str, Any] = dict(name=index, guest_role=guest_role, roles=roles) if index == GLOBAL_ROLES: diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index f881dc8..bbdf895 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -147,7 +147,7 @@ def _elastic_aggregate( aggr["aggs"]["aggregations"] = aggregation_dsl(aggregations) kargs = {} if filters or queries: - q = build_body(queries=queries.values(), filters=filters) + q = build_body(queries=queries, filters=filters) kargs["query"] = q["query"] result = es().search( index=index if isinstance(index, str) else ",".join(index), diff --git a/amcat4/index.py b/amcat4/index.py index 50af74a..722a898 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -171,7 +171,11 @@ def create_index( """ Create a new index in elasticsearch and register it with this AmCAT instance """ - set_fields(index, DEFAULT_INDEX_FIELDS) + default_mapping = {} + for field, settings in DEFAULT_INDEX_FIELDS.items(): + default_mapping[field] = ES_MAPPINGS[settings.type] + + es().indices.create(index=index, mappings={"properties": default_mapping}) register_index( index, guest_role=guest_role, @@ -179,6 +183,7 @@ def create_index( description=description, admin=admin, ) + set_fields(index, DEFAULT_INDEX_FIELDS) def register_index( @@ -303,14 +308,15 @@ def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField] what their (elastic) types are. 
So whenever fields are added or the type is updated, we also update the index mapping. """ + system_index = get_settings().system_index try: es().get(index=system_index, id=index, source_includes="fields") + fields = get_fields(index) except NotFoundError: - raise ValueError(f"Index {index} is not registered") + fields = {} type_mappings = {} - fields = get_fields(index) # Field type specific validation for field, settings in new_fields.items(): @@ -329,7 +335,7 @@ def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField] type_mappings[field] = ES_MAPPINGS[new_settings.type] new_settings.metareader = new_settings.metareader or DEFAULT_METAREADER[new_settings.type] - fields[field] = Field(**new_settings.model_dump()) + fields[field] = Field(**new_settings.model_dump(exclude_none=True)) else: # Update field # it is not possible to update elastic field types, but we can change amcat types (see ES_MAPPINGS) @@ -410,13 +416,10 @@ def get_role(index: str, email: str) -> Role: if index == GLOBAL_ROLES: return Role.NONE - guest_role: str | None = doc["_source"].get("guest_role", None) - if guest_role and guest_role.lower() != "none": - return Role[guest_role] - return Role.NONE + return get_guest_role(index) -def get_guest_role(index: str) -> Optional[Role]: +def get_guest_role(index: str) -> Role: """ Return the guest role for this index, raising a IndexDoesNotExist if the index does not exist :returns: a Role object, or None if global role was NONE @@ -432,7 +435,7 @@ def get_guest_role(index: str) -> Optional[Role]: role = d["_source"].get("guest_role") if role and role.lower() != "none": return Role[role] - return None + return Role.NONE def get_global_role(email: str, only_es: bool = False) -> Optional[Role]: @@ -448,79 +451,36 @@ def get_global_role(email: str, only_es: bool = False) -> Optional[Role]: return get_role(index=GLOBAL_ROLES, email=email) -def merge_overlapping_fields(index1: str, index2: str, name: str, field1: Field, field2: Field) -> Field: - """ - Merge two fields, and take most restrictive metareader settings. 
- """ - - if field1.type != field2.type: - raise ValueError(f"Field {name} has different types in index {index1} ({field1.type}) and {index2} ({field2.type})") - - if field1.metareader.access == "none" or field2.metareader.access == "none": - return Field(type=field1.type, metareader=FieldMetareaderAccess(access="none")) - - if field1.metareader.access == "snippet" and field2.metareader.access == "snippet": - if field1.metareader.max_snippet is None: - return field1 - if field2.metareader.max_snippet is None: - return field2 - - nomatch_chars = min(field1.metareader.max_snippet.nomatch_chars, field2.metareader.max_snippet.nomatch_chars) - max_matches = min(field1.metareader.max_snippet.max_matches, field2.metareader.max_snippet.max_matches) - match_chars = min(field1.metareader.max_snippet.match_chars, field2.metareader.max_snippet.match_chars) - return Field( - type=field1.type, - metareader=FieldMetareaderAccess( - access="snippet", - max_snippet=SnippetParams( - nomatch_chars=nomatch_chars, - max_matches=max_matches, - match_chars=match_chars, - ), - ), - ) - - if field1.metareader.access == "snippet" or field2.metareader.access == "snippet": - return Field(type=field1.type, metareader=FieldMetareaderAccess(access="snippet")) - return Field(type=field1.type, metareader=FieldMetareaderAccess(access="read")) - - def _get_index_fields(index: str) -> Iterator[tuple[str, str]]: r = es().indices.get_mapping(index=index) for k, v in r[index]["mappings"]["properties"].items(): yield k, v.get("type", "object") -def get_fields(index: str | list[str]) -> dict[str, Field]: +def get_fields(index: str) -> dict[str, Field]: """ Retrieve the fields settings for this index. Look for both the field settings in the system index, and the field mappings in the index itself. If a field is not defined in the system index, return the default settings for that field type. 
""" fields: dict[str, Field] = {} - indices = [index] if isinstance(index, str) else index - for index in indices: - try: - d = es().get( - index=get_settings().system_index, - id=index, - source_includes="fields", - ) - except NotFoundError: - raise IndexDoesNotExist(index) + try: + d = es().get( + index=get_settings().system_index, + id=index, + source_includes="fields", + ) index_fields = _fields_from_elastic(d["_source"].get("fields", {})) + except NotFoundError: + index_fields = {} - for field, fieldtype in _get_index_fields(index): - if field not in index_fields: - settings = Field(type=fieldtype, metareader=DEFAULT_METAREADER[fieldtype]) - else: - settings = index_fields[field] + for field, fieldtype in _get_index_fields(index): + if field not in index_fields: + fields[field] = Field(type=fieldtype, metareader=DEFAULT_METAREADER[fieldtype]) + else: + fields[field] = index_fields[field] - if field not in fields: - fields[field] = settings - else: - fields[field] = merge_overlapping_fields(index, index, field, fields[field], settings) return fields @@ -595,9 +555,10 @@ def upload_documents(index: str, documents: list[dict[str, str]], fields: dict[s def es_actions(index, documents): field_types = get_fields(index) - for document_pydantic in documents: - document = document_pydantic.model_dump() + for document in documents: for key in document.keys(): + if key == "_id": + continue if key not in field_types: raise ValueError(f"The type for field {key} is not yet specified") document[key] = coerce_type_to_elastic(document[key], field_types[key].type) diff --git a/amcat4/query.py b/amcat4/query.py index 02a441a..fb7fd00 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -235,9 +235,7 @@ def query_highlight(fields: list[FieldSpec], highlight_queries: bool = False) -> "number_of_fragments": field.snippet.max_matches, "fragment_size": field.snippet.match_chars or 1, # 0 would return the whole field } - print(field.snippet) if field.snippet.max_matches == 0: - print("heey") # If max_matches is zero, we drop the query for highlighting so that # the nomatch_chars are returned highlight["fields"][field.name]["highlight_query"] = {"match_all": {}} diff --git a/tests/conftest.py b/tests/conftest.py index d3e2a43..4b515e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -from typing import Iterable - import pytest import responses from fastapi.testclient import TestClient @@ -17,6 +15,7 @@ set_global_role, upload_documents, ) +from amcat4.models import UpdateField from tests.middlecat_keypair import PUBLIC_KEY UNITS = [ @@ -136,7 +135,7 @@ def guest_index(): delete_index(index, ignore_missing=True) -def upload(index: str, docs: Iterable[dict], **kwargs): +def upload(index: str, docs: list[dict[str, str]], fields: dict[str, UpdateField] | None = None): """ Upload these docs to the index, giving them an incremental id, and flush """ @@ -148,7 +147,7 @@ def upload(index: str, docs: Iterable[dict], **kwargs): for k, v in defaults.items(): if k not in doc: doc[k] = v - upload_documents(index, docs, **kwargs) + upload_documents(index, docs, fields) refresh_index(index) return ids @@ -191,7 +190,7 @@ def populate_index(index): upload( index, TEST_DOCUMENTS, - fields={"cat": "keyword", "subcat": "keyword", "i": "long"}, + fields={"cat": UpdateField(type="keyword"), "subcat": UpdateField(type="keyword"), "i": UpdateField(type="long")}, ) return TEST_DOCUMENTS diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 739137e..c01c493 100644 --- a/tests/test_aggregate.py +++ 
b/tests/test_aggregate.py @@ -2,6 +2,7 @@ from datetime import datetime, date from amcat4.aggregate import query_aggregate, Axis, Aggregation +from amcat4.api.query import _standardize_queries from tests.conftest import upload from tests.tools import dictset @@ -34,9 +35,9 @@ def test_aggregate(index_docs): def test_aggregate_querystring(index_docs): q = functools.partial(do_query, index_docs) - assert q("cat", queries=["toto"]) == {"a": 1, "b": 1} - assert q("cat", queries=["test*"]) == {"a": 2, "b": 1} - assert q("cat", queries=['"a text"', "another"]) == {"a": 2} + assert q("cat", queries=_standardize_queries(["toto"])) == {"a": 1, "b": 1} + assert q("cat", queries=_standardize_queries(["test*"])) == {"a": 2, "b": 1} + assert q("cat", queries=_standardize_queries(['"a text"', "another"])) == {"a": 2} def test_interval(index_docs): @@ -55,18 +56,18 @@ def test_second_axis(index_docs): def test_count(index_docs): """Does aggregation without axes work""" assert do_query(index_docs) == {(): 4} - assert do_query(index_docs, queries=["text"]) == {(): 2} + assert do_query(index_docs, queries={"text": "text"}) == {(): 2} def test_byquery(index_docs): """Get number of documents per query""" - assert do_query(index_docs, Axis("_query"), queries=["text", "test*"]) == {"text": 2, "test*": 3} - assert do_query(index_docs, Axis("_query"), Axis("subcat"), queries=["text", "test*"]) == { + assert do_query(index_docs, Axis("_query"), queries={"text": "text", "test*": "test*"}) == {"text": 2, "test*": 3} + assert do_query(index_docs, Axis("_query"), Axis("subcat"), queries={"text": "text", "test*": "test*"}) == { ("text", "x"): 2, ("test*", "x"): 1, ("test*", "y"): 2, } - assert do_query(index_docs, Axis("subcat"), Axis("_query"), queries=["text", "test*"]) == { + assert do_query(index_docs, Axis("subcat"), Axis("_query"), queries={"text": "text", "test*": "test*"}) == { ("x", "text"): 2, ("x", "test*"): 1, ("y", "test*"): 2, diff --git a/tests/test_api_index.py b/tests/test_api_index.py index 2a5b562..e18175f 100644 --- a/tests/test_api_index.py +++ b/tests/test_api_index.py @@ -1,7 +1,9 @@ from starlette.testclient import TestClient from amcat4 import elastic -from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role + +from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role, set_fields +from amcat4.models import Field from tests.tools import build_headers, post_json, get_json, check, refresh @@ -28,9 +30,7 @@ def test_create_list_delete_index(client, index_name, user, writer, writer2, adm # Users can only see indices that they have a role in or that have a guest role assert index_name not in {x["name"] for x in get_json(client, "/index/", user=user)} - assert index_name not in { - x["name"] for x in get_json(client, "/index/", user=writer2) - } + assert index_name not in {x["name"] for x in get_json(client, "/index/", user=writer2)} assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer)} # (Only) index admin can change index guest role @@ -79,15 +79,13 @@ def test_fields_upload(client: TestClient, user: str, index: str): } for i, x in enumerate(["a", "a", "b"]) ], - "columns": {"x": "keyword"}, + "fields": {"x": "keyword"}, } # You need METAREADER permissions to read fields, and WRITER to upload docs check(client.get(f"/index/{index}/fields"), 401) check( - client.post( - f"/index/{index}/documents", headers=build_headers(user), json=body - ), + client.post(f"/index/{index}/documents", headers=build_headers(user), 
json=body), 401, ) @@ -96,9 +94,7 @@ def test_fields_upload(client: TestClient, user: str, index: str): assert set(fields.keys()) == {"title", "date", "text", "url"} assert fields["date"]["type"] == "date" check( - client.post( - f"/index/{index}/documents", headers=build_headers(user), json=body - ), + client.post(f"/index/{index}/documents", headers=build_headers(user), json=body), 401, ) @@ -110,14 +106,8 @@ def test_fields_upload(client: TestClient, user: str, index: str): assert doc["title"] == "doc 0" # field selection - assert set( - get_json( - client, f"/index/{index}/documents/0", user=user, params={"fields": "title"} - ).keys() - ) == {"title"} - assert ( - get_json(client, f"/index/{index}/fields", user=user)["x"]["type"] == "keyword" - ) + assert set(get_json(client, f"/index/{index}/documents/0", user=user, params={"fields": "title"}).keys()) == {"title"} + assert get_json(client, f"/index/{index}/fields", user=user)["x"]["type"] == "keyword" elastic.es().indices.refresh() assert set(get_json(client, f"/index/{index}/fields/x/values", user=user)) == { "a", @@ -125,9 +115,7 @@ def test_fields_upload(client: TestClient, user: str, index: str): } -def test_set_get_delete_roles( - client: TestClient, admin: str, writer: str, user: str, index: str -): +def test_set_get_delete_roles(client: TestClient, admin: str, writer: str, user: str, index: str): body = {"email": user, "role": "READER"} # Anon, unauthorized; READER can't add users check(client.post(f"/index/{index}/users", json=body), 401) @@ -159,15 +147,10 @@ def test_set_get_delete_roles( json={"email": writer, "role": "WRITER"}, user=admin, ) - assert get_json(client, f"/index/{index}/users", user=writer) == [ - {"email": writer, "role": "WRITER"} - ] + assert get_json(client, f"/index/{index}/users", user=writer) == [{"email": writer, "role": "WRITER"}] # Writer can now add a new user post_json(client, f"/index/{index}/users", json=body, user=writer) - users = { - u["email"]: u["role"] - for u in get_json(client, f"/index/{index}/users", user=writer) - } + users = {u["email"]: u["role"] for u in get_json(client, f"/index/{index}/users", user=writer)} assert users == {writer: "WRITER", user: "READER"} # Anon, unauthorized or READER can't change users @@ -183,15 +166,10 @@ def test_set_get_delete_roles( client.put(user_url, json={"role": "WRITER"}, headers=build_headers(writer)), 200, ) - users = { - u["email"]: u["role"] - for u in get_json(client, f"/index/{index}/users", user=writer) - } + users = {u["email"]: u["role"] for u in get_json(client, f"/index/{index}/users", user=writer)} assert users == {writer: "WRITER", user: "WRITER"} # Writer can't change to admin - check( - client.put(writer_url, json={"role": "ADMIN"}, headers=build_headers(user)), 401 - ) + check(client.put(writer_url, json={"role": "ADMIN"}, headers=build_headers(user)), 401) # Writer can't change from admin set_role(index, writer, Role.ADMIN) check( @@ -217,18 +195,14 @@ def test_name_description(client, index, index_name, user, admin): check(client.put(f"/index/{index}", json=dict(name="test")), 401) check(client.get(f"/index/{index}"), 401) check( - client.put( - f"/index/{index}", json=dict(name="test"), headers=build_headers(user) - ), + client.put(f"/index/{index}", json=dict(name="test"), headers=build_headers(user)), 401, ) check(client.get(f"/index/{index}", headers=build_headers(user)), 401) # global admin and index writer can change details check( - client.put( - f"/index/{index}", json=dict(name="test"), headers=build_headers(admin) - ), + 
client.put(f"/index/{index}", json=dict(name="test"), headers=build_headers(admin)), 200, ) set_role(index, user, Role.ADMIN) @@ -257,7 +231,7 @@ def test_name_description(client, index, index_name, user, admin): json=dict( id=index_name, description="test2", - guest_role="metareader", + guest_role="METAREADER", summary_field="party", ), headers=build_headers(admin), @@ -273,7 +247,7 @@ def test_name_description(client, index, index_name, user, admin): assert indices[index_name]["description"] == "test2" # can set and get summary field - elastic.set_fields(index_name, {"party": "keyword"}) + set_fields(index_name, {"party": Field(type="keyword")}) refresh() check( client.put( @@ -283,6 +257,4 @@ def test_name_description(client, index, index_name, user, admin): ), 200, ) - assert ( - get_json(client, f"/index/{index_name}", user=admin)["summary_field"] == "party" - ) + assert get_json(client, f"/index/{index_name}", user=admin)["summary_field"] == "party" diff --git a/tests/test_api_user.py b/tests/test_api_user.py index afebc80..5854e28 100644 --- a/tests/test_api_user.py +++ b/tests/test_api_user.py @@ -45,24 +45,26 @@ def test_get_user(client: TestClient, writer, user): assert get_json(client, f"/users/{user}", user=user) == {"email": user, "role": "READER"} # writer can see everyone assert get_json(client, f"/users/{user}", user=writer) == {"email": user, "role": "READER"} - assert get_json(client, f"/users/{writer}", user=writer) == {"email": writer, "role": 'WRITER'} + assert get_json(client, f"/users/{writer}", user=writer) == {"email": writer, "role": "WRITER"} # Retrieving a non-existing user as admin should give 404 delete_user(user) - assert client.get(f'/users/{user}', headers=build_headers(writer)).status_code == 404 + assert client.get(f"/users/{user}", headers=build_headers(writer)).status_code == 404 def test_create_user(client: TestClient, user, writer, admin, username): # anonymous or unprivileged users cannot create new users - assert client.post('/users/').status_code == 401, "Creating user should require auth" + assert client.post("/users/").status_code == 401, "Creating user should require auth" assert client.post("/users/", headers=build_headers(writer)).status_code == 401, "Creating user should require admin" # users need global role - assert client.post("/users/", headers=build_headers(admin), json=dict(email=username)).status_code == 400, \ - "Duplicate create should return 400" + assert ( + client.post("/users/", headers=build_headers(admin), json=dict(email=username)).status_code == 400 + ), "Duplicate create should return 400" # admin can add new users - u = dict(email=username, role="writer") + u = dict(email=username, role="WRITER") assert "email" in set(post_json(client, "/users/", user=admin, json=u).keys()) - assert client.post("/users/", headers=build_headers(admin), json=u).status_code == 400, \ - "Duplicate create should return 400" + assert ( + client.post("/users/", headers=build_headers(admin), json=u).status_code == 400 + ), "Duplicate create should return 400" # users can delete themselves, others cannot delete them assert client.delete(f"/users/{username}", headers=build_headers(writer)).status_code == 401 @@ -75,8 +77,8 @@ def test_create_user(client: TestClient, user, writer, admin, username): def test_modify_user(client: TestClient, user, writer, admin): """Are the API endpoints and auth for modifying users correct?""" # Only admin can change users - check(client.put(f"/users/{user}", headers=build_headers(user), json={'role': 'metareader'}), 
401) - check(client.put(f"/users/{user}", headers=build_headers(admin), json={'role': 'admin'}), 200) + check(client.put(f"/users/{user}", headers=build_headers(user), json={"role": "METAREADER"}), 401) + check(client.put(f"/users/{user}", headers=build_headers(admin), json={"role": "ADMIN"}), 200) assert get_global_role(user).name == "ADMIN" @@ -85,5 +87,5 @@ def test_list_users(client: TestClient, index, admin, user): check(client.get("/users"), 401) check(client.get("/users", headers=build_headers(user)), 401) result = get_json(client, "/users", user=admin) - assert {'email': admin, 'role': 'ADMIN'} in result - assert {'email': user, 'role': 'READER'} in result + assert {"email": admin, "role": "ADMIN"} in result + assert {"email": user, "role": "READER"} in result diff --git a/tests/test_index.py b/tests/test_index.py index 61e5f99..10852a9 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -3,7 +3,7 @@ import pytest from amcat4.config import get_settings -from amcat4.elastic import es, set_fields +from amcat4.elastic import es from amcat4.index import ( Role, create_index, @@ -24,7 +24,9 @@ set_global_role, set_guest_role, set_role, + set_fields, ) +from amcat4.models import Field from tests.tools import refresh @@ -93,7 +95,7 @@ def test_list_indices(index, guest_index, admin): def test_global_roles(): user = "user@example.com" - assert get_global_role(user) is None + assert get_global_role(user) == Role.NONE set_global_role(user, Role.ADMIN) refresh_index(get_settings().system_index) assert get_global_role(user) == Role.ADMIN @@ -102,12 +104,12 @@ def test_global_roles(): assert get_global_role(user) == Role.WRITER remove_global_role(user) refresh_index(get_settings().system_index) - assert get_global_role(user) is None + assert get_global_role(user) == Role.NONE def test_index_roles(index): user = "user@example.com" - assert get_role(index, user) is None + assert get_role(index, user) == Role.NONE set_role(index, user, Role.METAREADER) refresh_index(get_settings().system_index) assert get_role(index, user) == Role.METAREADER @@ -116,11 +118,11 @@ def test_index_roles(index): assert get_role(index, user) == Role.ADMIN remove_role(index, user) refresh_index(get_settings().system_index) - assert get_role(index, user) is None + assert get_role(index, user) == Role.NONE def test_guest_role(index): - assert get_guest_role(index) is None + assert get_guest_role(index) == Role.NONE set_guest_role(index, Role.READER) refresh() assert get_guest_role(index) == Role.READER @@ -163,7 +165,7 @@ def test_summary_field(index): modify_index(index, summary_field="doesnotexist") with pytest.raises(Exception): modify_index(index, summary_field="title") - set_fields(index, {"party": "keyword"}) + set_fields(index, {"party": Field(type="keyword")}) modify_index(index, summary_field="party") assert get_index(index).summary_field == "party" modify_index(index, summary_field="date") diff --git a/tests/tools.py b/tests/tools.py index 1a7e27f..d8b12ea 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -12,7 +12,7 @@ from tests.middlecat_keypair import PRIVATE_KEY -def create_token(**payload) -> bytes: +def create_token(**payload) -> str: header = {"alg": "RS256"} token = jwt.encode(header, payload, PRIVATE_KEY) return token.decode("utf-8") @@ -35,9 +35,7 @@ def get_json(client: TestClient, url, expected=200, headers=None, user=None, **k """Get the given URL. 
If expected is 2xx, return the result as parsed json""" response = client.get(url, headers=build_headers(user, headers), **kargs) content = response.json() if response.content else None - assert ( - response.status_code == expected - ), f"GET {url} returned {response.status_code}, expected {expected}, {content}" + assert response.status_code == expected, f"GET {url} returned {response.status_code}, expected {expected}, {content}" if expected // 100 == 2: return content @@ -45,8 +43,7 @@ def get_json(client: TestClient, url, expected=200, headers=None, user=None, **k def post_json(client: TestClient, url, expected=201, headers=None, user=None, **kargs): response = client.post(url, headers=build_headers(user, headers), **kargs) assert response.status_code == expected, ( - f"POST {url} returned {response.status_code}, expected {expected}\n" - f"{response.json()}" + f"POST {url} returned {response.status_code}, expected {expected}\n" f"{response.json()}" ) if expected != 204: return response.json() From 5c9663b3f97eb8dc92ca4dd77451c55224dfc82b Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Thu, 18 Jan 2024 17:38:00 +0100 Subject: [PATCH 15/80] progress --- amcat4/api/index.py | 10 ++--- amcat4/api/query.py | 3 +- amcat4/index.py | 10 +++-- amcat4/query.py | 24 +++++------ tests/conftest.py | 4 +- tests/test_api_index.py | 40 ++++++++--------- tests/test_api_metareader.py | 83 +++++++++++++++++++----------------- tests/test_api_user.py | 4 +- tests/test_index.py | 6 +-- tests/test_pagination.py | 28 ++++++------ tests/test_query.py | 12 ++++-- tests/tools.py | 5 +-- 12 files changed, 119 insertions(+), 110 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index b5a1068..86bc5f2 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -1,6 +1,6 @@ """API Endpoints for document and index management.""" from http import HTTPStatus -from typing import Annotated, Literal +from typing import Annotated, Any, Literal import elasticsearch from elastic_transport import ApiError @@ -144,7 +144,7 @@ def delete_index(ix: str, user: str = Depends(authenticated_user)): @app_index.post("/{ix}/documents", status_code=status.HTTP_201_CREATED) def upload_documents( ix: str, - documents: Annotated[list[dict[str, str]], Body(description="The documents to upload")], + documents: Annotated[list[dict[str, Any]], Body(description="The documents to upload")], fields: Annotated[ dict[str, str | UpdateField] | None, Body(description="Optional Specification of field (column) types") ] = None, @@ -241,11 +241,7 @@ def get_fields(ix: str, user: str = Depends(authenticated_user)): Returns a json array of {name, type} objects """ check_role(user, index.Role.METAREADER, ix) - - if "," in ix: - return index.get_fields(ix.split(",")) - else: - return index.get_fields(ix) + return index.get_fields(ix) @app_index.post("/{ix}/fields") diff --git a/amcat4/api/query.py b/amcat4/api/query.py index 79ffa0c..411fc23 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -99,7 +99,7 @@ def _standardize_filters( """Convert filters to dict format: {field: {values: []}}.""" if not filters: return None - + print(filters) f = dict() for field, filter_ in filters.items(): if isinstance(filter_, str): @@ -110,6 +110,7 @@ def _standardize_filters( f[field] = filter_ else: raise ValueError(f"Cannot parse filter: {filter_}") + print(f) return f diff --git a/amcat4/index.py b/amcat4/index.py index 722a898..54f3076 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -33,7 +33,7 @@ """ import collections from enum 
import IntEnum -from typing import Iterable, Iterator, Optional, Literal +from typing import Any, Iterable, Iterator, Optional, Literal import hashlib import json @@ -438,7 +438,7 @@ def get_guest_role(index: str) -> Role: return Role.NONE -def get_global_role(email: str, only_es: bool = False) -> Optional[Role]: +def get_global_role(email: str, only_es: bool = False) -> Role: """ Retrieve the global role of this user @@ -514,6 +514,8 @@ def coerce_type_to_elastic(value, ftype): Coerces values into the respective type in elastic based on ES_MAPPINGS and elastic field types """ + # TODO: aks Wouter what this is based on, and why it doesn't + # actually seem to be based on ES_MAPPINGS if ftype in ["keyword", "constant_keyword", "wildcard", "url", "tag", "text"]: value = str(value) elif ftype in [ @@ -523,7 +525,6 @@ def coerce_type_to_elastic(value, ftype): "double", "float", "half_float", - "half_float", "unsigned_long", ]: value = float(value) @@ -544,7 +545,7 @@ def _get_hash(document: dict) -> str: return m.hexdigest() -def upload_documents(index: str, documents: list[dict[str, str]], fields: dict[str, UpdateField] | None = None) -> None: +def upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[str, UpdateField] | None = None) -> None: """ Upload documents to this index @@ -557,6 +558,7 @@ def es_actions(index, documents): field_types = get_fields(index) for document in documents: for key in document.keys(): + print(key) if key == "_id": continue if key not in field_types: diff --git a/amcat4/query.py b/amcat4/query.py index fb7fd00..1c0ed71 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -4,14 +4,10 @@ from math import ceil from typing import ( - Mapping, - Iterable, - Optional, Union, Sequence, Any, Dict, - List, Tuple, Literal, ) @@ -73,7 +69,7 @@ def parse_queries(queries: dict[str, str]) -> dict: fs, runtime_mappings = [], {} if filters: for field, filter in filters.items(): - extra_runtime_mappings, filter_term = parse_filter(field, filter) + extra_runtime_mappings, filter_term = parse_filter(field, filter.model_dump()) fs.append(filter_term) if extra_runtime_mappings: runtime_mappings.update(extra_runtime_mappings) @@ -125,7 +121,7 @@ def as_dict(self) -> dict: def query_documents( index: Union[str, Sequence[str]], - fields: list[FieldSpec], + fields: list[FieldSpec] | None = None, queries: dict[str, str] | None = None, filters: dict[str, FilterSpec] | None = None, sort: list[dict[str, SortSpec]] | None = None, @@ -136,7 +132,7 @@ def query_documents( scroll_id: str | None = None, highlight: bool = False, **kwargs, -) -> QueryResult | None: +) -> QueryResult: """ Conduct a query_string query, returning the found documents. @@ -145,8 +141,8 @@ def query_documents( If the scroll parameter is given, the result will contain a scroll_id which can be used to get the next batch. In case there are no more documents to scroll, it will return None :param index: The name of the index or indexes - :param fields: List of fields using the FieldSpec syntax. We enforce specific field selection here. Any logic - for determining whether a user can see the field should be done in the API layer. + :param fields: List of fields using the FieldSpec syntax. If not specified, only return _id. + !Any logic for determining whether a user can see the field should be done in the API layer. 
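    For illustration, the shapes of the main arguments (FieldSpec/FilterSpec
    from amcat4.models; the names and values are made up):

        fields = [FieldSpec(name="title"), FieldSpec(name="text")]
        queries = {"jobs": "work* OR employ*"}
        filters = {"date": FilterSpec(gte="2018-01-01"), "cat": FilterSpec(values=["a"])}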
:param queries: if not None, a dict with labels and queries {label1: query1, ...} :param filters: if not None, a dict where the key is the field and the value is a FilterSpec @@ -174,14 +170,14 @@ def query_documents( if scroll_id: result = es().scroll(scroll_id=scroll_id, **kwargs) if not result["hits"]["hits"]: - return None + return QueryResult(data=[]) else: - h = query_highlight(fields, highlight) + h = query_highlight(fields, highlight) if fields is not None else None body = build_body(queries, filters, h) - if fields: - fieldnames = [field.name for field in fields] - kwargs["_source"] = fieldnames + fieldnames = [field.name for field in fields] if fields is not None else ["_id"] + kwargs["_source"] = fieldnames + if not scroll: kwargs["from_"] = page * per_page result = es().search(index=index, size=per_page, **body, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 4b515e9..96a3b87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from typing import Any import pytest import responses from fastapi.testclient import TestClient @@ -135,7 +136,7 @@ def guest_index(): delete_index(index, ignore_missing=True) -def upload(index: str, docs: list[dict[str, str]], fields: dict[str, UpdateField] | None = None): +def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, UpdateField] | None = None): """ Upload these docs to the index, giving them an incremental id, and flush """ @@ -213,6 +214,7 @@ def index_many(): upload( index, [dict(id=i, pagenr=abs(10 - i), text=text) for (i, text) in enumerate(["odd", "even"] * 10)], + fields={"id": UpdateField(type="long"), "pagenr": UpdateField(type="long")}, ) yield index delete_index(index, ignore_missing=True) diff --git a/tests/test_api_index.py b/tests/test_api_index.py index e18175f..a95623d 100644 --- a/tests/test_api_index.py +++ b/tests/test_api_index.py @@ -20,7 +20,7 @@ def test_create_list_delete_index(client, index_name, user, writer, writer2, adm # Writers can create indices post_json(client, "/index/", user=writer, json=dict(id=index_name)) refresh() - assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer)} + assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer) or []} # Users can GET their own index, global writer can GET all indices, others cannot GET non-public indices check(client.get(f"/index/{index_name}"), 401) @@ -29,9 +29,9 @@ def test_create_list_delete_index(client, index_name, user, writer, writer2, adm check(client.get(f"/index/{index_name}", headers=build_headers(user=writer2)), 200) # Users can only see indices that they have a role in or that have a guest role - assert index_name not in {x["name"] for x in get_json(client, "/index/", user=user)} - assert index_name not in {x["name"] for x in get_json(client, "/index/", user=writer2)} - assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer)} + assert index_name not in {x["name"] for x in get_json(client, "/index/", user=user) or []} + assert index_name not in {x["name"] for x in get_json(client, "/index/", user=writer2) or []} + assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer) or []} # (Only) index admin can change index guest role check(client.put(f"/index/{index_name}", json={"guest_role": "METAREADER"}), 401) @@ -62,7 +62,7 @@ def test_create_list_delete_index(client, index_name, user, writer, writer2, adm assert get_guest_role(index_name).name == "READER" # Index should now be visible to non-authorized users 
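    # (Aside on the scroll change above: query_documents now returns an empty
    # QueryResult rather than None when a scroll is exhausted, so callers can
    # loop until a page comes back empty. A minimal sketch, with an
    # illustrative index name and query:
    #     res = query_documents("docs", queries={"q": "test*"}, scroll="5m")
    #     docs = list(res.data)
    #     while res.data:
    #         res = query_documents("docs", scroll_id=res.scroll_id)
    #         docs += res.data
    # )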
- assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer)} + assert index_name in {x["name"] for x in get_json(client, "/index/", user=writer) or []} check(client.get(f"/index/{index_name}", headers=build_headers(user=user)), 200) @@ -90,7 +90,7 @@ def test_fields_upload(client: TestClient, user: str, index: str): ) set_role(index, user, Role.METAREADER) - fields = get_json(client, f"/index/{index}/fields", user=user) + fields = get_json(client, f"/index/{index}/fields", user=user) or {} assert set(fields.keys()) == {"title", "date", "text", "url"} assert fields["date"]["type"] == "date" check( @@ -101,15 +101,17 @@ def test_fields_upload(client: TestClient, user: str, index: str): set_role(index, user, Role.WRITER) post_json(client, f"/index/{index}/documents", user=user, json=body) get_json(client, f"/index/{index}/refresh", expected=204) - doc = get_json(client, f"/index/{index}/documents/0", user=user) + doc = get_json(client, f"/index/{index}/documents/0", user=user) or {} assert set(doc.keys()) == {"date", "text", "title", "x"} assert doc["title"] == "doc 0" # field selection - assert set(get_json(client, f"/index/{index}/documents/0", user=user, params={"fields": "title"}).keys()) == {"title"} - assert get_json(client, f"/index/{index}/fields", user=user)["x"]["type"] == "keyword" + assert set((get_json(client, f"/index/{index}/documents/0", user=user, params={"fields": "title"}) or {}).keys()) == { + "title" + } + assert (get_json(client, f"/index/{index}/fields", user=user) or {})["x"]["type"] == "keyword" elastic.es().indices.refresh() - assert set(get_json(client, f"/index/{index}/fields/x/values", user=user)) == { + assert set(get_json(client, f"/index/{index}/fields/x/values", user=user) or []) == { "a", "b", } @@ -150,7 +152,7 @@ def test_set_get_delete_roles(client: TestClient, admin: str, writer: str, user: assert get_json(client, f"/index/{index}/users", user=writer) == [{"email": writer, "role": "WRITER"}] # Writer can now add a new user post_json(client, f"/index/{index}/users", json=body, user=writer) - users = {u["email"]: u["role"] for u in get_json(client, f"/index/{index}/users", user=writer)} + users = {u["email"]: u["role"] for u in get_json(client, f"/index/{index}/users", user=writer) or []} assert users == {writer: "WRITER", user: "READER"} # Anon, unauthorized or READER can't change users @@ -166,7 +168,7 @@ def test_set_get_delete_roles(client: TestClient, admin: str, writer: str, user: client.put(user_url, json={"role": "WRITER"}, headers=build_headers(writer)), 200, ) - users = {u["email"]: u["role"] for u in get_json(client, f"/index/{index}/users", user=writer)} + users = {u["email"]: u["role"] for u in get_json(client, f"/index/{index}/users", user=writer) or []} assert users == {writer: "WRITER", user: "WRITER"} # Writer can't change to admin check(client.put(writer_url, json={"role": "ADMIN"}, headers=build_headers(user)), 401) @@ -216,14 +218,14 @@ def test_name_description(client, index, index_name, user, admin): ) # global admin and index or guest metareader can read details - assert get_json(client, f"/index/{index}", user=admin)["description"] == "ooktest" - assert get_json(client, f"/index/{index}", user=user)["name"] == "test" + assert (get_json(client, f"/index/{index}", user=admin) or {})["description"] == "ooktest" + assert (get_json(client, f"/index/{index}", user=user) or {})["name"] == "test" set_role(index, user, Role.METAREADER) - assert get_json(client, f"/index/{index}", user=user)["name"] == "test" + assert 
(get_json(client, f"/index/{index}", user=user) or {})["name"] == "test" set_role(index, user, None) check(client.get(f"/index/{index}", headers=build_headers(user)), 401) set_guest_role(index, Role.METAREADER) - assert get_json(client, f"/index/{index}", user=user)["name"] == "test" + assert (get_json(client, f"/index/{index}", user=user) or {})["name"] == "test" check( client.post( @@ -238,11 +240,11 @@ def test_name_description(client, index, index_name, user, admin): ), 201, ) - assert get_json(client, f"/index/{index_name}", user=user)["description"] == "test2" + assert (get_json(client, f"/index/{index_name}", user=user) or {})["description"] == "test2" # name and description should be present in list of indices refresh() - indices = {ix["id"]: ix for ix in get_json(client, "/index")} + indices = {ix["id"]: ix for ix in get_json(client, "/index") or []} assert indices[index]["description"] == "ooktest" assert indices[index_name]["description"] == "test2" @@ -257,4 +259,4 @@ def test_name_description(client, index, index_name, user, admin): ), 200, ) - assert get_json(client, f"/index/{index_name}", user=admin)["summary_field"] == "party" + assert (get_json(client, f"/index/{index_name}", user=admin) or {})["summary_field"] == "party" diff --git a/tests/test_api_metareader.py b/tests/test_api_metareader.py index 54f1e18..d5d6f26 100644 --- a/tests/test_api_metareader.py +++ b/tests/test_api_metareader.py @@ -1,51 +1,30 @@ from fastapi.testclient import TestClient +from amcat4.models import FieldSpec, SnippetParams -from tests.tools import get_json, build_headers, post_json +from tests.tools import build_headers, post_json def create_index_metareader(client, index, admin): # Create new user and set index role to metareader - client.post( - f"/users", - headers=build_headers(admin), - json={"email": "meta@reader.com", "role": "METAREADER"}, - ), - client.put( - f"/index/{index}/users/meta@reader.com", - headers=build_headers(admin), - json={"role": "METAREADER"}, - ), + client.post("/users", headers=build_headers(admin), json={"email": "meta@reader.com", "role": "METAREADER"}), + client.put(f"/index/{index}/users/meta@reader.com", headers=build_headers(admin), json={"role": "METAREADER"}), -def set_metareader_access(client, index, admin, access): +def set_metareader_access(client, index, admin, metareader): client.post( f"/index/{index}/fields", headers=build_headers(admin), - json={"text": {"type": "text", "meta": {"metareader_access": access}}}, + json={"text": {"type": "text", "metareader": metareader}}, ) -def check_allowed(client, index, field=None, allowed=True): - params = {} - body = {} - - if field: - params["fields"] = field - body["fields"] = [field] - - get_json( - client, - f"/index/{index}/documents", - user="meta@reader.com", - expected=200 if allowed else 401, - params=params, - ) +def check_allowed(client, index: str, field: FieldSpec, allowed=True): post_json( client, f"/index/{index}/query", user="meta@reader.com", expected=200 if allowed else 401, - json=body, + json={"fields": [field.model_dump()]}, ) @@ -55,9 +34,13 @@ def test_metareader_none(client: TestClient, admin, index_docs): Metareader should not be able to get field both full and as snippet """ create_index_metareader(client, index_docs, admin) - set_metareader_access(client, index_docs, admin, "none") - check_allowed(client, index_docs, field="text", allowed=False) - check_allowed(client, index_docs, field="text[150;3;50]", allowed=False) + set_metareader_access(client, index_docs, admin, {"access": 
"none"}) + + full = FieldSpec(name="text") + snippet = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=150, max_matches=3, match_chars=50)) + + check_allowed(client, index_docs, full, allowed=False) + check_allowed(client, index_docs, field=snippet, allowed=False) def test_metareader_read(client: TestClient, admin, index_docs): @@ -66,9 +49,13 @@ def test_metareader_read(client: TestClient, admin, index_docs): Metareader should be able to get field both full and as snippet """ create_index_metareader(client, index_docs, admin) - set_metareader_access(client, index_docs, admin, "read") - check_allowed(client, index_docs, field="text", allowed=True) - check_allowed(client, index_docs, field="text[150;3;50]", allowed=True) + set_metareader_access(client, index_docs, admin, {"access": "read"}) + + full = FieldSpec(name="text") + snippet = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=150, max_matches=3, match_chars=50)) + + check_allowed(client, index_docs, field=full, allowed=True) + check_allowed(client, index_docs, field=snippet, allowed=True) def test_metareader_snippet(client: TestClient, admin, index_docs): @@ -78,9 +65,25 @@ def test_metareader_snippet(client: TestClient, admin, index_docs): with maximum parameters of nomatch_chars=50, max_matches=1, match_chars=20 """ create_index_metareader(client, index_docs, admin) + set_metareader_access( + client, + index_docs, + admin, + {"access": "snippet", "max_snippet": {"nomatch_chars": 50, "max_matches": 1, "match_chars": 20}}, + ) + + full = FieldSpec(name="text") + snippet_too_long = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=51, max_matches=1, match_chars=20)) + snippet_too_many_matches = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=50, max_matches=2, match_chars=20)) + snippet_too_long_matches = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=50, max_matches=1, match_chars=21)) + + snippet_just_right = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=50, max_matches=1, match_chars=20)) + snippet_less_than_allowed = FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=49, max_matches=0, match_chars=19)) + + check_allowed(client, index_docs, field=full, allowed=False) + check_allowed(client, index_docs, field=snippet_too_long, allowed=False) + check_allowed(client, index_docs, field=snippet_too_many_matches, allowed=False) + check_allowed(client, index_docs, field=snippet_too_long_matches, allowed=False) - set_metareader_access(client, index_docs, admin, "snippet[50;1;20]") - check_allowed(client, index_docs, field="text", allowed=False) - check_allowed(client, index_docs, field="text[51;1;20]", allowed=False) - check_allowed(client, index_docs, field="text[50,1,20]", allowed=True) - check_allowed(client, index_docs, field="text[49;1;20]", allowed=True) + check_allowed(client, index_docs, field=snippet_just_right, allowed=True) + check_allowed(client, index_docs, field=snippet_less_than_allowed, allowed=True) diff --git a/tests/test_api_user.py b/tests/test_api_user.py index 5854e28..37b6bd8 100644 --- a/tests/test_api_user.py +++ b/tests/test_api_user.py @@ -61,7 +61,7 @@ def test_create_user(client: TestClient, user, writer, admin, username): ), "Duplicate create should return 400" # admin can add new users u = dict(email=username, role="WRITER") - assert "email" in set(post_json(client, "/users/", user=admin, json=u).keys()) + assert "email" in set((post_json(client, "/users/", user=admin, json=u) or {}).keys()) assert ( client.post("/users/", 
headers=build_headers(admin), json=u).status_code == 400 ), "Duplicate create should return 400" @@ -86,6 +86,6 @@ def test_list_users(client: TestClient, index, admin, user): # You need global WRITER rights to list users check(client.get("/users"), 401) check(client.get("/users", headers=build_headers(user)), 401) - result = get_json(client, "/users", user=admin) + result = get_json(client, "/users", user=admin) or {} assert {"email": admin, "role": "ADMIN"} in result assert {"email": user, "role": "READER"} in result diff --git a/tests/test_index.py b/tests/test_index.py index 10852a9..bd12acc 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -37,7 +37,7 @@ def list_es_indices() -> List[str]: return list(es().indices.get(index="*").keys()) -def list_index_names(email: str = None) -> List[str]: +def list_index_names(email: str | None = None) -> List[str]: return [ix.name for ix in list_known_indices(email)] @@ -54,9 +54,9 @@ def test_create_delete_index(): assert index in list_index_names() # Cannot create or register duplicate index with pytest.raises(Exception): - create_index(index.name) + create_index(index) with pytest.raises(Exception): - register_index(index.name) + register_index(index) delete_index(index) refresh_index(get_settings().system_index) assert index not in list_es_indices() diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 7988694..852f56c 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,5 +1,4 @@ from typing import List - from amcat4.query import query_documents @@ -12,22 +11,23 @@ def test_pagination(index_many): x = query_documents(index_many, per_page=6, page=3) assert x.page_count == 4 assert x.per_page == 6 - assert len(x.data) == 20 - 3*6 + assert len(x.data) == 20 - 3 * 6 assert x.page == 3 def test_sort(index_many): - def q(key, per_page=5, *args, **kwargs) -> List[int]: - res = query_documents(index_many, per_page=per_page, sort=key, *args, **kwargs) - return [int(h['_id']) for h in res.data] - assert q('id') == [0, 1, 2, 3, 4] - assert q('pagenr') == [10, 9, 11, 8, 12] - assert q(['pagenr', 'id']) == [10, 9, 11, 8, 12] - assert q([{'pagenr': {"order": "desc"}}, 'id']) == [0, 1, 19, 2, 18] + def q(key, per_page=5) -> List[int]: + res = query_documents(index_many, per_page=per_page, sort=key) + return [int(h["_id"]) for h in res.data] + + assert q("id") == [0, 1, 2, 3, 4] + assert q("pagenr") == [10, 9, 11, 8, 12] + assert q(["pagenr", "id"]) == [10, 9, 11, 8, 12] + assert q([{"pagenr": {"order": "desc"}}, "id"]) == [0, 1, 19, 2, 18] def test_scroll(index_many): - r = query_documents(index_many, queries=["odd"], scroll='5m', per_page=4) + r = query_documents(index_many, queries={"odd": "odd"}, scroll="5m", per_page=4) assert len(r.data) == 4 assert r.total_count, 10 assert r.page_count == 3 @@ -41,6 +41,10 @@ def test_scroll(index_many): assert len(r.data) == 2 allids += r.data + print(allids) + # TODO: wth happens when people upload an id field? does it + # overwrite _id, or is _id in this case coincidentally also + # serial int? 
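+    # (Partial answer to the TODO above: in elasticsearch a _source field
+    # named "id" is an ordinary field and never overwrites the built-in _id
+    # metadata. The two coincide in this fixture only because the `upload`
+    # helper in tests/conftest.py assigns incremental _id values itself, per
+    # its docstring, matching the id=i field used here.)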
r = query_documents(index_many, scroll_id=r.scroll_id) - assert r is None - assert {int(h['_id']) for h in allids} == {0, 2, 4, 6, 8, 10, 12, 14, 16, 18} + assert len(r.data) == 0 + assert {int(h["id"]) for h in allids} == {0, 2, 4, 6, 8, 10, 12, 14, 16, 18} diff --git a/tests/test_query.py b/tests/test_query.py index d8d4b39..80623b8 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,8 +1,8 @@ import functools -import re from typing import Set, Optional from amcat4 import query +from amcat4.models import FieldSpec from tests.conftest import upload @@ -37,7 +37,7 @@ def test_range_query(index_docs): def test_fields(index_docs): - res = query.query_documents(index_docs, queries=["test"], fields=["cat", "title"]) + res = query.query_documents(index_docs, queries={"1": "test"}, fields=[FieldSpec(name="cat"), FieldSpec(name="title")]) assert res is not None assert set(res.data[0].keys()) == {"cat", "title", "_id"} @@ -46,14 +46,18 @@ def test_highlight(index): words = "The error of regarding functional notions is not quite equivalent to" text = f"{words} a test document. {words} other text documents. {words} you!" upload(index, [dict(title="Een test titel", text=text)]) - res = query.query_documents(index, fields=["title", "text"], queries=["te*"], highlight=True) + res = query.query_documents( + index, fields=[FieldSpec(name="title"), FieldSpec(name="text")], queries={"1": "te*"}, highlight=True + ) assert res is not None doc = res.data[0] assert doc["title"] == "Een test titel" assert doc["text"] == f"{words} a test document. {words} other text documents. {words} you!" # snippets can also have highlights - doc = query.query_documents(index, queries=["te*"], fields=["title"], snippets=["text"], highlight=True).data[0] + doc = query.query_documents( + index, queries={"1": "te*"}, fields=[FieldSpec(name="title")], snippets=["text"], highlight=True + ).data[0] assert doc["title"] == "Een test titel" assert " a test" in doc["text"] assert " ... " in doc["text"] diff --git a/tests/tools.py b/tests/tools.py index d8b12ea..1532f5c 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -3,7 +3,6 @@ from datetime import datetime, date from typing import Set, Iterable, Optional -import requests from authlib.jose import jwt from fastapi.testclient import TestClient @@ -31,7 +30,7 @@ def build_headers(user=None, headers=None): return headers -def get_json(client: TestClient, url, expected=200, headers=None, user=None, **kargs): +def get_json(client: TestClient, url: str, expected=200, headers=None, user=None, **kargs): """Get the given URL. 
If expected is 2xx, return the result as parsed json""" response = client.get(url, headers=build_headers(user, headers), **kargs) content = response.json() if response.content else None @@ -61,7 +60,7 @@ def dictset(dicts: Iterable[dict]) -> Set[str]: return {json.dumps(dict(sorted(d.items())), cls=DateTimeEncoder) for d in dicts} -def check(response: requests.Response, expected: int, msg: Optional[str] = None): +def check(response, expected: int, msg: Optional[str] = None): assert response.status_code == expected, ( f"{msg or ''}{': ' if msg else ''}Unexpected status: received {response.status_code} != expected {expected};" f" reply: {response.json()}" From 7af640e76fbee9e62c3ce5e9cedeea05d643d150 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 19 Jan 2024 11:15:19 +0100 Subject: [PATCH 16/80] before I destroy my laptop --- amcat4/__main__.py | 15 ++++++++------- amcat4/config.py | 8 +++++--- amcat4/date_mappings.py | 2 +- amcat4/index.py | 5 ++--- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/amcat4/__main__.py b/amcat4/__main__.py index 61f92a3..76224f5 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -22,7 +22,7 @@ from amcat4.config import get_settings, AuthOptions, validate_settings from amcat4.elastic import connect_elastic, get_system_version, ping from amcat4.index import GLOBAL_ROLES, create_index, set_global_role, Role, list_global_users, upload_documents -from amcat4.models import Field +from amcat4.models import UpdateField SOTU_INDEX = "state_of_the_union" @@ -47,7 +47,7 @@ def upload_test_data() -> str: ) for row in csvfile ] - columns = dict(president=Field(type="keyword"), party=Field(type="keyword"), year=Field(type="double")) + columns = dict(president=UpdateField(type="keyword"), party=UpdateField(type="keyword"), year=UpdateField(type="double")) upload_documents(SOTU_INDEX, docs, columns) return SOTU_INDEX @@ -188,7 +188,7 @@ def config_amcat(args): for fieldname in settings.model_fields_set: if fieldname not in settings_dict: continue - fieldinfo = settings.model_fields[fieldname] + fieldinfo = settings_dict[fieldname] validation_function = AuthOptions.validate if fieldname == "auth" else None value = getattr(settings, fieldname) value = menu(fieldname, fieldinfo, value, validation_function=validation_function) @@ -199,7 +199,7 @@ def config_amcat(args): with env_file_location.open("w") as f: for fieldname, value in settings_dict.items(): - fieldinfo = settings.model_fields[fieldname] + fieldinfo = settings_dict[fieldname] if doc := fieldinfo.description: f.write(f"# {doc}\n") if _isenum(fieldinfo): @@ -225,16 +225,17 @@ def bold(x): def _isenum(fieldinfo: FieldInfo) -> bool: try: - return issubclass(fieldinfo.annotation, Enum) + return issubclass(fieldinfo.annotation, Enum) if fieldinfo.annotation is not None else False except TypeError: return False def menu(fieldname: str, fieldinfo: FieldInfo, value, validation_function=None): print(f"\n{bold(fieldname)}: {fieldinfo.description}") - if _isenum(fieldinfo): + if _isenum(fieldinfo) and fieldinfo.annotation: print(" Possible choices:") - for option in fieldinfo.annotation: + options: Any = fieldinfo.annotation + for option in options: print(f" - {option.name}: {option.__doc__}") print() print(f"The current value for {bold(fieldname)} is {bold(value)}.") diff --git a/amcat4/config.py b/amcat4/config.py index 0f70d60..6d339f2 100644 --- a/amcat4/config.py +++ b/amcat4/config.py @@ -15,6 +15,8 @@ from pydantic import model_validator, Field from pydantic_settings import BaseSettings, 
SettingsConfigDict +ENV_PREFIX = "amcat4_" + class AuthOptions(str, Enum): #: everyone (that can reach the server) can do anything they want @@ -120,7 +122,7 @@ def set_ssl(self) -> "Settings": } return self - model_config = SettingsConfigDict(env_prefix="amcat4_") + model_config = SettingsConfigDict(env_prefix=ENV_PREFIX) @functools.lru_cache() @@ -145,5 +147,5 @@ def validate_settings(): if __name__ == "__main__": # Echo the settings - for k, v in get_settings().dict().items(): - print(f"{Settings.Config.env_prefix.upper()}{k.upper()}={v}") + for k, v in get_settings().model_dump().items(): + print(f"{ENV_PREFIX.upper()}{k.upper()}={v}") diff --git a/amcat4/date_mappings.py b/amcat4/date_mappings.py index 0dc77c6..fcb3d6f 100644 --- a/amcat4/date_mappings.py +++ b/amcat4/date_mappings.py @@ -3,7 +3,7 @@ class DateMapping: - interval = None + interval: str | None = None def mapping(self, field: str) -> dict: return {self.fieldname(field): {"type": self.mapping_type(), "script": self.mapping_script(field)}} diff --git a/amcat4/index.py b/amcat4/index.py index 54f3076..6cc6232 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -345,7 +345,7 @@ def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField] f"Field {field} already exists with type {current.type}, cannot change to {new_settings.type}" ) # set new field settings (amcat type, metareader, etc.) - fields[field] = updateField(current, new_settings) + fields[field] = updateField(field=current, update=new_settings) es().indices.put_mapping(index=index, properties=type_mappings) es().update( @@ -551,14 +551,13 @@ def upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[s :param index: The name of the index (without prefix) :param documents: A sequence of article dictionaries - :param fields: A mapping of field:type for field types + :param fields: A mapping of fieldname:UpdateField for field types """ def es_actions(index, documents): field_types = get_fields(index) for document in documents: for key in document.keys(): - print(key) if key == "_id": continue if key not in field_types: From a2a401a7d9dbc125081dc890941c35477e1d292d Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 19 Jan 2024 15:52:42 +0100 Subject: [PATCH 17/80] fixing tests --- amcat4/api/query.py | 6 ++-- amcat4/query.py | 6 ++-- tests/test_api_pagination.py | 11 +++---- tests/test_pagination.py | 14 ++++----- tests/test_query.py | 61 ++++++++++++++++++++++++++---------- 5 files changed, 61 insertions(+), 37 deletions(-) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index 411fc23..b2e1913 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -3,6 +3,7 @@ from typing import Annotated, Dict, List, Optional, Any, Union, Iterable, Literal from fastapi import APIRouter, HTTPException, status, Depends, Response, Body +from pydantic import InstanceOf from pydantic.main import BaseModel from amcat4 import query, aggregate @@ -99,8 +100,8 @@ def _standardize_filters( """Convert filters to dict format: {field: {values: []}}.""" if not filters: return None - print(filters) - f = dict() + + f: dict[str, FilterSpec] = {} for field, filter_ in filters.items(): if isinstance(filter_, str): f[field] = FilterSpec(values=[filter_]) @@ -110,7 +111,6 @@ def _standardize_filters( f[field] = filter_ else: raise ValueError(f"Cannot parse filter: {filter_}") - print(f) return f diff --git a/amcat4/query.py b/amcat4/query.py index 1c0ed71..0ed9ebd 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -25,8 +25,8 @@ def 
build_body( highlight: dict | None = None, ids: list[str] | None = None, ): - def parse_filter(field, filter) -> Tuple[dict, dict]: - filter = filter.copy() + def parse_filter(field: str, filterSpec: FilterSpec) -> Tuple[dict, dict]: + filter = filterSpec.model_dump(exclude_none=True) extra_runtime_mappings = {} field_filters = [] for value in filter.pop("values", []): @@ -69,7 +69,7 @@ def parse_queries(queries: dict[str, str]) -> dict: fs, runtime_mappings = [], {} if filters: for field, filter in filters.items(): - extra_runtime_mappings, filter_term = parse_filter(field, filter.model_dump()) + extra_runtime_mappings, filter_term = parse_filter(field, filter) fs.append(filter_term) if extra_runtime_mappings: runtime_mappings.update(extra_runtime_mappings) diff --git a/tests/test_api_pagination.py b/tests/test_api_pagination.py index 8139013..ac58863 100644 --- a/tests/test_api_pagination.py +++ b/tests/test_api_pagination.py @@ -7,11 +7,12 @@ def test_pagination(client, index, user): """Does basic pagination work?""" set_role(index, user, Role.READER) + # TODO. Tests are not independent. test_pagination fails if run directly after other tests. + # Probably delete_index doesn't fully delete + upload(index, docs=[{"i": i} for i in range(66)]) url = f"/index/{index}/documents" - r = get_json( - client, url, user=user, params={"sort": "i", "per_page": 20, "fields": ["i"]} - ) + r = get_json(client, url, user=user, params={"sort": "i", "per_page": 20, "fields": ["i"]}) assert r["meta"]["per_page"] == 20 assert r["meta"]["page"] == 0 assert r["meta"]["page_count"] == 4 @@ -24,9 +25,7 @@ def test_pagination(client, index, user): ) assert r["meta"]["page"] == 3 assert {h["i"] for h in r["results"]} == {60, 61, 62, 63, 64, 65} - r = get_json( - client, url, user=user, params={"sort": "i", "per_page": 20, "page": 4} - ) + r = get_json(client, url, user=user, params={"sort": "i", "per_page": 20, "page": 4}) assert len(r["results"]) == 0 # Test POST query diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 852f56c..4f63a66 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,4 +1,5 @@ from typing import List +from amcat4.models import FieldSpec from amcat4.query import query_documents @@ -27,24 +28,21 @@ def q(key, per_page=5) -> List[int]: def test_scroll(index_many): - r = query_documents(index_many, queries={"odd": "odd"}, scroll="5m", per_page=4) + r = query_documents(index_many, queries={"odd": "odd"}, scroll="5m", per_page=4, fields=[FieldSpec(name="id")]) assert len(r.data) == 4 assert r.total_count, 10 assert r.page_count == 3 allids = list(r.data) - r = query_documents(index_many, scroll_id=r.scroll_id) + r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")]) assert len(r.data) == 4 allids += r.data - r = query_documents(index_many, scroll_id=r.scroll_id) + r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")]) assert len(r.data) == 2 allids += r.data - print(allids) - # TODO: wth happens when people upload an id field? does it - # overwrite _id, or is _id in this case coincidentally also - # serial int? 
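# (Note on the parse_filter change earlier in this patch: FilterSpec is a
# pydantic model, so model_dump(exclude_none=True) produces a plain dict with
# only the bounds/values that were actually set, e.g.
#     FilterSpec(gte="2018-01-01").model_dump(exclude_none=True)
#     # -> {"gte": "2018-01-01"}
# The "values" entries are popped off for exact-match filtering; the remaining
# keys (gt/gte/lt/lte and the like) are left over as range bounds.)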
- r = query_documents(index_many, scroll_id=r.scroll_id) + r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")]) + assert len(r.data) == 0 assert {int(h["id"]) for h in allids} == {0, 2, 4, 6, 8, 10, 12, 14, 16, 18} diff --git a/tests/test_query.py b/tests/test_query.py index 80623b8..30b43a1 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -2,13 +2,22 @@ from typing import Set, Optional from amcat4 import query -from amcat4.models import FieldSpec +from amcat4.models import FieldSpec, FilterSpec, FilterValue, SnippetParams, UpdateField +from amcat4.api.query import _standardize_queries, _standardize_filters from tests.conftest import upload -def query_ids(index: str, q: Optional[str] = None, **kwargs) -> Set[int]: +def query_ids( + index: str, + q: Optional[str | list[str]] = None, + filters: dict[str, FilterValue | list[FilterValue] | FilterSpec] | None = None, + **kwargs, +) -> Set[int]: if q is not None: - kwargs["queries"] = [q] + kwargs["queries"] = _standardize_queries(q) + if filters is not None: + kwargs["filters"] = _standardize_filters(filters) + res = query.query_documents(index, **kwargs) if res is None: return set() @@ -17,23 +26,34 @@ def query_ids(index: str, q: Optional[str] = None, **kwargs) -> Set[int]: def test_query(index_docs): q = functools.partial(query_ids, index_docs) + assert q("test") == {1, 2} assert q("test*") == {1, 2, 3} assert q('"a text"') == {0} - assert q(queries=["this", "toto"]) == {0, 2, 3} + assert q(["this", "toto"]) == {0, 2, 3} - assert q(filters={"title": {"value": "title"}}) == {0, 1} - assert q("this", filters={"title": {"value": "title"}}) == {0} + assert q(filters={"title": ["title"]}) == {0, 1} + assert q("this", filters={"title": ["title"]}) == {0} assert q("this") == {0, 2} +def test_snippet(index_docs): + docs = query.query_documents(index_docs, fields=[FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=5))]) + assert docs.data[0]["text"] == "this is" + + docs = query.query_documents( + index_docs, queries={"1": "a"}, fields=[FieldSpec(name="text", snippet=SnippetParams(max_matches=1, match_chars=1))] + ) + assert docs.data[0]["text"] == "a" + + def test_range_query(index_docs): q = functools.partial(query_ids, index_docs) - assert q(filters={"date": {"gt": "2018-02-01"}}) == {2} - assert q(filters={"date": {"gte": "2018-02-01"}}) == {1, 2} - assert q(filters={"date": {"gte": "2018-02-01", "lt": "2020-01-01"}}) == {1} - assert q("title", filters={"date": {"gt": "2018-01-01"}}) == {1} + assert q(filters={"date": FilterSpec(gt="2018-02-01")}) == {2} + assert q(filters={"date": FilterSpec(gte="2018-02-01")}) == {1, 2} + assert q(filters={"date": FilterSpec(gte="2018-02-01", lt="2020-01-01")}) == {1} + assert q("title", filters={"date": FilterSpec(gt="2018-01-01")}) == {1} def test_fields(index_docs): @@ -54,9 +74,14 @@ def test_highlight(index): assert doc["title"] == "Een test titel" assert doc["text"] == f"{words} a test document. {words} other text documents. {words} you!" 
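# (The rewrite below drops the separate snippets= argument: a snippet is now
# requested per field, e.g.
#     FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=150, max_matches=3, match_chars=50))
# so full fields and snippets travel through the same fields= parameter.)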
- # snippets can also have highlights doc = query.query_documents( - index, queries={"1": "te*"}, fields=[FieldSpec(name="title")], snippets=["text"], highlight=True + index, + queries={"1": "te*"}, + fields=[ + FieldSpec(name="title", snippet=SnippetParams(max_matches=3, match_chars=50)), + FieldSpec(name="text", snippet=SnippetParams(max_matches=3, match_chars=50)), + ], + highlight=True, ).data[0] assert doc["title"] == "Een test titel" assert " a test" in doc["text"] @@ -64,13 +89,15 @@ def test_highlight(index): def test_query_multiple_index(index_docs, index): - upload(index, [{"text": "also a text", "i": -1}]) + upload(index, [{"text": "also a text", "i": -1}], fields={"i": UpdateField(type="long")}) docs = query.query_documents([index_docs, index]) assert docs is not None assert len(docs.data) == 5 -def test_query_filter_mapping(index_docs): - q = functools.partial(query_ids, index_docs) - assert q(filters={"date": {"monthnr": "2"}}) == {1} - assert q(filters={"date": {"dayofweek": "Monday"}}) == {0, 3} +# TODO: Do we want to support this? What are the options? +# If so, need to add it to FilterSpec +# def test_query_filter_mapping(index_docs): +# q = functools.partial(query_ids, index_docs) +# assert q(filters={"date": {"monthnr": "2"}}) == {1} +# assert q(filters={"date": {"dayofweek": "Monday"}}) == {0, 3} From e668cb2694312e0203132bb070b7f021b457c246 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 24 Jan 2024 12:03:10 +0100 Subject: [PATCH 18/80] redecorating. Separating amcat and elastic type --- amcat4/__main__.py | 6 +- amcat4/aggregate.py | 2 +- amcat4/api/auth.py | 4 +- amcat4/api/index.py | 72 ++++++------ amcat4/fields.py | 267 ++++++++++++++++++++++++++++++++++++++++++ amcat4/index.py | 219 ++++------------------------------ amcat4/models.py | 46 +++++++- tests/test_elastic.py | 4 +- 8 files changed, 374 insertions(+), 246 deletions(-) create mode 100644 amcat4/fields.py diff --git a/amcat4/__main__.py b/amcat4/__main__.py index 76224f5..b9eb183 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -22,7 +22,7 @@ from amcat4.config import get_settings, AuthOptions, validate_settings from amcat4.elastic import connect_elastic, get_system_version, ping from amcat4.index import GLOBAL_ROLES, create_index, set_global_role, Role, list_global_users, upload_documents -from amcat4.models import UpdateField +from amcat4.models import ElasticType SOTU_INDEX = "state_of_the_union" @@ -47,8 +47,8 @@ def upload_test_data() -> str: ) for row in csvfile ] - columns = dict(president=UpdateField(type="keyword"), party=UpdateField(type="keyword"), year=UpdateField(type="double")) - upload_documents(SOTU_INDEX, docs, columns) + fields: dict[str, ElasticType] = {"president": "keyword", "party": "keyword", "year": "short"} + upload_documents(SOTU_INDEX, docs, fields) return SOTU_INDEX diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index bbdf895..ab768ef 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -6,7 +6,7 @@ from amcat4.date_mappings import interval_mapping from amcat4.elastic import es -from amcat4.index import get_fields +from amcat4.fields import get_fields from amcat4.query import build_body from amcat4.models import Field, FilterSpec diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py index be941e0..04c55ed 100644 --- a/amcat4/api/auth.py +++ b/amcat4/api/auth.py @@ -12,8 +12,8 @@ from amcat4.models import FieldSpec from amcat4.config import get_settings, AuthOptions -from amcat4.index import Role, get_role, get_global_role, get_fields - +from 
amcat4.index import Role, get_role, get_global_role +from amcat4.fields import get_fields oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 86bc5f2..2a73721 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -11,27 +11,14 @@ from amcat4.api.auth import authenticated_user, authenticated_writer, check_role from amcat4.index import refresh_system_index, remove_role, set_role -from amcat4.models import UpdateField +from amcat4.fields import field_values, field_stats +from amcat4.models import CreateField, ElasticType, Field, UpdateField app_index = APIRouter(prefix="/index", tags=["index"]) RoleType = Literal["ADMIN", "WRITER", "READER", "METAREADER"] -def _standardize_updatefields(fields: dict[str, str | UpdateField]) -> dict[str, UpdateField]: - standardized_fields: dict[str, UpdateField] = {} - - for name, field in fields.items(): - if isinstance(field, UpdateField): - standardized_fields[name] = field - elif isinstance(field, str): - standardized_fields[name] = UpdateField(type=field) - else: - raise ValueError(f"Cannot parse field: {field}") - - return standardized_fields - - @app_index.get("/") def index_list(current_user: str = Depends(authenticated_user)): """ @@ -145,23 +132,18 @@ def delete_index(ix: str, user: str = Depends(authenticated_user)): def upload_documents( ix: str, documents: Annotated[list[dict[str, Any]], Body(description="The documents to upload")], - fields: Annotated[ - dict[str, str | UpdateField] | None, Body(description="Optional Specification of field (column) types") + types: Annotated[ + dict[str, ElasticType] | None, + Body(description="If a field in documents does not yet exist, you need to specify an elastic type"), ] = None, user: str = Depends(authenticated_user), ): """ - Upload documents to this server. - - JSON payload should contain a `documents` key, and may contain a `columns` key: - Returns a list of ids for the uploaded documents + Upload documents to this server. 
Returns a list of ids for the uploaded documents
     """
     check_role(user, index.Role.WRITER, ix)

-    if fields is None:
-        return index.upload_documents(ix, documents)
-    else:
-        return index.upload_documents(ix, documents, _standardize_updatefields(fields))
+    return index.upload_documents(ix, documents, types)


 @app_index.get("/{ix}/documents/{docid}")
@@ -233,6 +215,31 @@ def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user)
     )


+@app_index.post("/{ix}/fields")
+def create_fields(
+    ix: str,
+    fields: Annotated[dict[str, ElasticType | CreateField], Body(description="")],
+    user: str = Depends(authenticated_user),
+):
+    """
+    Create fields. The value for each field is either just an elastic type, or a
+    CreateField that also sets the initial field settings
+    """
+    check_role(user, index.Role.WRITER, ix)
+
+    types = {}
+    update_fields = {}
+    for field, value in fields.items():
+        if isinstance(value, CreateField):
+            types[field] = value.elastic_type
+            update_fields[field] = UpdateField(**value.model_dump(exclude_none=True))
+        else:
+            types[field] = value
+
+    if len(update_fields) > 0:
+        index.set_fields(ix, update_fields)
+    return "", HTTPStatus.NO_CONTENT
+
+
 @app_index.get("/{ix}/fields")
 def get_fields(ix: str, user: str = Depends(authenticated_user)):
     """
@@ -244,17 +251,16 @@ def get_fields(ix: str, user: str = Depends(authenticated_user)):
     return index.get_fields(ix)


-@app_index.post("/{ix}/fields")
-def set_fields(
-    ix: str, fields: Annotated[dict[str, str | UpdateField], Body(description="")], user: str = Depends(authenticated_user)
+@app_index.put("/{ix}/fields")
+def update_fields(
+    ix: str, fields: Annotated[dict[str, UpdateField], Body(description="")], user: str = Depends(authenticated_user)
 ):
     """
-    Set the field types used in this index.
-
+    Update the field settings
     """
     check_role(user, index.Role.WRITER, ix)

-    index.set_fields(ix, _standardize_updatefields(fields))
+    index.set_fields(ix, fields)
     return "", HTTPStatus.NO_CONTENT


@@ -270,7 +276,7 @@ def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user
     efficient, since elastic has to aggregate all values first.
     """
     check_role(user, index.Role.READER, ix)
-    values = index.get_field_values(ix, field, size=2001)
+    values = field_values(ix, field, size=2001)
    if len(values) > 2000:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
@@ -283,7 +289,7 @@ def get_field_values(ix: str, field: str, user: str = Depends(authenticated_user
 def get_field_stats(ix: str, field: str, user: str = Depends(authenticated_user)):
     """Get statistics for a specific value. Only works for numeric (incl date) fields."""
     check_role(user, index.Role.READER, ix)
-    return index.get_field_stats(ix, field)
+    return field_stats(ix, field)


 @app_index.get("/{ix}/users")
diff --git a/amcat4/fields.py b/amcat4/fields.py
new file mode 100644
index 0000000..1316eb8
--- /dev/null
+++ b/amcat4/fields.py
@@ -0,0 +1,267 @@
+"""
+We have two types of fields:
+- Elastic fields are the fields used under the hood by elastic (https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-types.html).
+  These are stored in the Mapping of an index
+- Amcat fields (Field) are the fields as seen by the amcat user. They use a simplified type, and contain additional information such as metareader access.
+  These are stored in the system index.
+
+We need to make sure that:
+- When a user sets a field, it needs to be changed in both places: the system index and the mapping
+- If a field only exists in the elastic mapping, we need to add the default Field to the system index.
+  This happens anytime get_fields is called, so that whenever a field is used it is guaranteed to be in the
+  system index
+"""
+
+
+from typing import Any, Iterator
+
+
+from elasticsearch import NotFoundError
+
+# from amcat4.api.common import py2dict
+from amcat4.config import get_settings
+from amcat4.elastic import es
+from amcat4.models import AmcatType, ElasticType, Field, FieldClientDisplay, UpdateField, updateField, FieldMetareaderAccess
+
+
+# given an elastic field type, infer the corresponding amcat type
+# (this is relevant if we are importing an index that does not yet have amcat field settings)
+TYPEMAP_ES_TO_AMCAT: dict[ElasticType, AmcatType] = {
+    # TEXT fields
+    "text": "text",
+    "annotated_text": "text",
+    "binary": "text",
+    "match_only_text": "text",
+    # DATE fields
+    "date": "date",
+    # BOOLEAN fields
+    "boolean": "boolean",
+    # KEYWORD fields
+    "keyword": "keyword",
+    "constant_keyword": "keyword",
+    "wildcard": "keyword",
+    # NUMBER fields
+    # - integer
+    "integer": "number",
+    "byte": "number",
+    "short": "number",
+    "long": "number",
+    "unsigned_long": "number",
+    # - float
+    "float": "number",
+    "half_float": "number",
+    "double": "number",
+    "scaled_float": "number",
+    # OBJECT fields
+    "object": "object",
+    "flattened": "object",
+    "nested": "object",
+    # VECTOR fields (exclude sparse vectors)
+    "dense_vector": "vector",
+    # GEO fields
+    "geo_point": "geo",
+}
+
+
+def get_default_metareader(amcat_type: AmcatType):
+    if amcat_type in ["boolean", "number", "date"]:
+        return FieldMetareaderAccess(access="read")
+
+    return FieldMetareaderAccess(access="none")
+
+
+def get_default_field(elastic_type: ElasticType):
+    amcat_type = TYPEMAP_ES_TO_AMCAT.get(elastic_type)
+    if amcat_type is None:
+        raise ValueError(f"Invalid elastic type: {elastic_type}")
+
+    return Field(type=amcat_type, elastic_type=elastic_type, metareader=get_default_metareader(amcat_type))
+
+
+# default fields when a new index is created
+DEFAULT_FIELDS = {
+    "text": Field(
+        type="text",
+        elastic_type="text",
+        metareader=FieldMetareaderAccess(access="none"),
+        client_display=FieldClientDisplay(in_list=True),
+    ),
+    "title": Field(
+        type="text",
+        elastic_type="text",
+        metareader=FieldMetareaderAccess(access="read"),
+        client_display=FieldClientDisplay(in_list=True),
+    ),
+    "date": Field(
+        type="date",
+        elastic_type="date",
+        metareader=FieldMetareaderAccess(access="read"),
+        client_display=FieldClientDisplay(in_list=True),
+    ),
+    "url": Field(
+        type="keyword",
+        elastic_type="keyword",
+        metareader=FieldMetareaderAccess(access="read"),
+        client_display=FieldClientDisplay(in_list=True),
+    ),
+}
+
+
+def coerce_type(value: Any, elastic_type: ElasticType):
+    """
+    Coerces values into the respective type in elastic,
+    based on the elastic field type
+    """
+    if elastic_type in ["text", "annotated_text", "binary", "match_only_text", "keyword", "constant_keyword", "wildcard"]:
+        return str(value)
+    if elastic_type in ["boolean"]:
+        return bool(value)
+    if elastic_type in ["long", "integer", "short", "byte", "unsigned_long"]:
+        return int(value)
+    if elastic_type in ["float", "half_float", "double", "scaled_float"]:
+        return float(value)
+
+    # TODO: check coercion / validation for object, vector and geo types
+    return value
+
+
+def create_elastic_fields(index: str, fields: dict[str, ElasticType]):
+    mapping: dict[str, Any] = {}
+    current_fields = {k: v for k, v in _get_index_fields(index)}
+
+    for field, elastic_type in fields.items():
+        if TYPEMAP_ES_TO_AMCAT.get(elastic_type) is None:
+            raise ValueError(f"Field type {elastic_type} 
not supported by AmCAT") + + current_type = current_fields.get(field) + if current_type is not None: + if current_type != elastic_type: + raise ValueError( + f"Field '{field}' already exists with type '{current_type}'. Cannot change type to '{elastic_type}'" + ) + continue + + mapping[field] = {"type": elastic_type} + + if elastic_type in ["date"]: + mapping[field]["format"] = "strict_date_optional_time" + + es().indices.create(index=index, mappings={"properties": mapping}) + + +def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: + return [{"field": field, "settings": settings.model_dump()} for field, settings in fields.items()] + + +def _fields_from_elastic( + fields: list[dict], +) -> dict[str, Field]: + return {fs["field"]: Field.model_validate(fs["settings"]) for fs in fields} + + +def set_fields(index: str, new_fields: dict[str, UpdateField] | dict[str, Field]): + """ + Set the fields settings for this index. Only updates fields that + already exist. type and elastic_type cannot be changed. + """ + + system_index = get_settings().system_index + fields = get_fields(index) + + for field, new_settings in new_fields.items(): + current = fields.get(field) + if current is None: + raise ValueError(f"Field {field} does not exist") + + if current.type != "text": + if new_settings.metareader and new_settings.metareader.access == "snippet": + raise ValueError(f"Field {field} is not of type text, cannot set metareader access to snippet") + + fields[field] = updateField(field=current, update=new_settings) + + es().update( + index=system_index, + id=index, + doc=dict(fields=_fields_to_elastic(fields)), + ) + + +def _get_index_fields(index: str) -> Iterator[tuple[str, ElasticType]]: + r = es().indices.get_mapping(index=index) + for k, v in r[index]["mappings"]["properties"].items(): + yield k, v.get("type", "object") + + +def get_fields(index: str) -> dict[str, Field]: + """ + Retrieve the fields settings for this index. Look for both the field settings in the system index, + and the field mappings in the index itself. If a field is not defined in the system index, return the + default settings for that field type and add it to the system index. This way, any elastic index can be imported + """ + fields: dict[str, Field] = {} + system_index = get_settings().system_index + + try: + d = es().get( + index=system_index, + id=index, + source_includes="fields", + ) + system_index_fields = _fields_from_elastic(d["_source"].get("fields", {})) + except NotFoundError: + system_index_fields = {} + + update_system_index = False + for field, elastic_type in _get_index_fields(index): + amcat_type = TYPEMAP_ES_TO_AMCAT.get(elastic_type) + + if amcat_type is None: + # skip over unsupported elastic fields. + # (TODO: also return warning to client?) + continue + + if field not in system_index_fields: + update_system_index = True + fields[field] = Field(type=amcat_type, elastic_type=elastic_type, metareader=get_default_metareader(amcat_type)) + else: + fields[field] = system_index_fields[field] + + if update_system_index: + es().update( + index=system_index, + id=index, + doc=dict(fields=_fields_to_elastic(fields)), + ) + + return fields + + +def field_values(index: str, field: str, size: int) -> list[str]: + """ + Get the values for a given field (e.g. 
to populate list of filter values on keyword field)
+    Results are sorted descending by document frequency
+    see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html#search-aggregations-bucket-terms-aggregation-order
+
+    :param index: The index
+    :param field: The field name
+    :return: A list of values
+    """
+    aggs = {"unique_values": {"terms": {"field": field, "size": size}}}
+    r = es().search(index=index, size=0, aggs=aggs)
+    return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]]
+
+
+def field_stats(index: str, field: str) -> dict:
+    """
+    Get summary statistics (count, min, max, avg, sum) for a numeric or date field
+
+    :param index: The index
+    :param field: The field name
+    :return: A dict with the summary statistics
+    """
+    aggs = {"facets": {"stats": {"field": field}}}
+    r = es().search(index=index, size=0, aggs=aggs)
+    return r["aggregations"]["facets"]
+
+
+def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None = None):
+    script_dict = dict(source=script, lang="painless", params=params or {})
+    es().update_by_query(index=index, script=script_dict, **query)
diff --git a/amcat4/index.py b/amcat4/index.py
index 6cc6232..075ed02 100644
--- a/amcat4/index.py
+++ b/amcat4/index.py
@@ -33,7 +33,7 @@
 """
 import collections
 from enum import IntEnum
-from typing import Any, Iterable, Iterator, Optional, Literal
+from typing import Any, Iterable, Optional, Literal
 import hashlib
 import json
@@ -44,45 +44,14 @@
 # from amcat4.api.common import py2dict
 from amcat4.config import get_settings
 from amcat4.elastic import es
-from amcat4.models import Field, FieldClientDisplay, SnippetParams, UpdateField, updateField, FieldMetareaderAccess
-
-
-# The Field model has a type field as we use it in amcat, but we need to
-# convert this to an elastic type. 
This is the mapping -ES_MAPPINGS = { - "long": {"type": "long"}, - "date": {"type": "date", "format": "strict_date_optional_time"}, - "double": {"type": "double"}, - "keyword": {"type": "keyword"}, - "url": {"type": "keyword"}, - "tag": {"type": "keyword"}, - "id": {"type": "keyword"}, - "text": {"type": "text"}, - "object": {"type": "object"}, - "geo_point": {"type": "geo_point"}, - "dense_vector": {"type": "dense_vector"}, -} - -DEFAULT_METAREADER = { - "long": FieldMetareaderAccess(access="read"), - "date": FieldMetareaderAccess(access="read"), - "double": FieldMetareaderAccess(access="read"), - "keyword": FieldMetareaderAccess(access="read"), - "url": FieldMetareaderAccess(access="read"), - "tag": FieldMetareaderAccess(access="read"), - "id": FieldMetareaderAccess(access="read"), - "text": FieldMetareaderAccess(access="none"), - "object": FieldMetareaderAccess(access="none"), - "geo_point": FieldMetareaderAccess(access="none"), - "dense_vector": FieldMetareaderAccess(access="none"), -} - -DEFAULT_INDEX_FIELDS = { - "text": Field(type="text", metareader=DEFAULT_METAREADER["text"], client_display=FieldClientDisplay(in_list=True)), - "title": Field(type="text", metareader=DEFAULT_METAREADER["text"], client_display=FieldClientDisplay(in_list=True)), - "date": Field(type="date", metareader=DEFAULT_METAREADER["date"], client_display=FieldClientDisplay(in_list=True)), - "url": Field(type="url", metareader=DEFAULT_METAREADER["url"], client_display=FieldClientDisplay(in_list=True)), -} +from amcat4.fields import ( + DEFAULT_FIELDS, + coerce_type, + create_elastic_fields, + get_fields, + set_fields, +) +from amcat4.models import ElasticType, Field class Role(IntEnum): @@ -171,9 +140,10 @@ def create_index( """ Create a new index in elasticsearch and register it with this AmCAT instance """ + default_mapping = {} - for field, settings in DEFAULT_INDEX_FIELDS.items(): - default_mapping[field] = ES_MAPPINGS[settings.type] + for field, settings in DEFAULT_FIELDS.items(): + default_mapping[field] = get_elastic_field_mapping(settings.type) es().indices.create(index=index, mappings={"properties": default_mapping}) register_index( @@ -183,7 +153,8 @@ def create_index( description=description, admin=admin, ) - set_fields(index, DEFAULT_INDEX_FIELDS) + + set_fields(index, DEFAULT_FIELDS) def register_index( @@ -289,72 +260,6 @@ def set_guest_role(index: str, guest_role: Optional[Role]): modify_index(index, guest_role=guest_role, remove_guest_role=(guest_role is None)) -def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: - return [{"field": field, "settings": settings.model_dump()} for field, settings in fields.items()] - - -def _fields_from_elastic( - fields: list[dict], -) -> dict[str, Field]: - return {fs["field"]: Field.model_validate(fs["settings"]) for fs in fields} - - -def set_fields(index: str, new_fields: dict[str, Field] | dict[str, UpdateField]): - """ - Set the fields settings for this index. - - Note that we're storing fields in two places. We keep all field settings in the system index. - But the index that contains the documents also needs to know what fields there are and - what their (elastic) types are. So whenever fields are added or the type is updated, we - also update the index mapping. 
- """ - - system_index = get_settings().system_index - try: - es().get(index=system_index, id=index, source_includes="fields") - fields = get_fields(index) - except NotFoundError: - fields = {} - - type_mappings = {} - - # Field type specific validation - for field, settings in new_fields.items(): - type = fields.get(field, settings).type - if type != "text": - if settings.metareader and settings.metareader.access == "snippet": - raise ValueError(f"Field {field} is not of type text, cannot set metareader access to snippet") - - for field, new_settings in new_fields.items(): - current = fields.get(field) - - if current is None: - # Create field - if new_settings.type is None: - raise ValueError(f"Field {field} does not yet exist, and to create a new field you need to specify a type") - - type_mappings[field] = ES_MAPPINGS[new_settings.type] - new_settings.metareader = new_settings.metareader or DEFAULT_METAREADER[new_settings.type] - fields[field] = Field(**new_settings.model_dump(exclude_none=True)) - else: - # Update field - # it is not possible to update elastic field types, but we can change amcat types (see ES_MAPPINGS) - if new_settings.type is not None: - if ES_MAPPINGS[current.type] != ES_MAPPINGS[new_settings.type]: - raise ValueError( - f"Field {field} already exists with type {current.type}, cannot change to {new_settings.type}" - ) - # set new field settings (amcat type, metareader, etc.) - fields[field] = updateField(field=current, update=new_settings) - - es().indices.put_mapping(index=index, properties=type_mappings) - es().update( - index=system_index, - id=index, - doc=dict(fields=_fields_to_elastic(fields)), - ) - - def modify_index( index: str, name: Optional[str] = None, @@ -369,6 +274,7 @@ def modify_index( guest_role=guest_role and guest_role.name, summary_field=summary_field, ) + if summary_field is not None: f = get_fields(index) if summary_field not in f: @@ -451,39 +357,6 @@ def get_global_role(email: str, only_es: bool = False) -> Role: return get_role(index=GLOBAL_ROLES, email=email) -def _get_index_fields(index: str) -> Iterator[tuple[str, str]]: - r = es().indices.get_mapping(index=index) - for k, v in r[index]["mappings"]["properties"].items(): - yield k, v.get("type", "object") - - -def get_fields(index: str) -> dict[str, Field]: - """ - Retrieve the fields settings for this index. Look for both the field settings in the system index, - and the field mappings in the index itself. If a field is not defined in the system index, return the - default settings for that field type. 
- """ - fields: dict[str, Field] = {} - - try: - d = es().get( - index=get_settings().system_index, - id=index, - source_includes="fields", - ) - index_fields = _fields_from_elastic(d["_source"].get("fields", {})) - except NotFoundError: - index_fields = {} - - for field, fieldtype in _get_index_fields(index): - if field not in index_fields: - fields[field] = Field(type=fieldtype, metareader=DEFAULT_METAREADER[fieldtype]) - else: - fields[field] = index_fields[field] - - return fields - - def list_users(index: str) -> dict[str, Role]: """ " List all users and their roles on the given index @@ -509,32 +382,6 @@ def delete_user(email: str) -> None: set_role(ix.id, email, None) -def coerce_type_to_elastic(value, ftype): - """ - Coerces values into the respective type in elastic - based on ES_MAPPINGS and elastic field types - """ - # TODO: aks Wouter what this is based on, and why it doesn't - # actually seem to be based on ES_MAPPINGS - if ftype in ["keyword", "constant_keyword", "wildcard", "url", "tag", "text"]: - value = str(value) - elif ftype in [ - "long", - "short", - "byte", - "double", - "float", - "half_float", - "unsigned_long", - ]: - value = float(value) - elif ftype in ["integer"]: - value = int(value) - elif ftype == "boolean": - value = bool(value) - return value - - def _get_hash(document: dict) -> str: """ Get the hash for a document @@ -545,7 +392,7 @@ def _get_hash(document: dict) -> str: return m.hexdigest() -def upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[str, UpdateField] | None = None) -> None: +def upload_documents(index: str, documents: list[dict[str, Any]], types: dict[str, ElasticType] | None = None) -> None: """ Upload documents to this index @@ -554,6 +401,9 @@ def upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[s :param fields: A mapping of fieldname:UpdateField for field types """ + if types: + create_elastic_fields(index, types) + def es_actions(index, documents): field_types = get_fields(index) for document in documents: @@ -562,14 +412,11 @@ def es_actions(index, documents): continue if key not in field_types: raise ValueError(f"The type for field {key} is not yet specified") - document[key] = coerce_type_to_elastic(document[key], field_types[key].type) + document[key] = coerce_type(document[key], field_types[key].elastic_type) if "_id" not in document: document["_id"] = _get_hash(document) yield {"_index": index, **document} - if fields: - set_fields(index, fields) - actions = list(es_actions(index, documents)) elasticsearch.helpers.bulk(es(), actions) @@ -607,32 +454,6 @@ def delete_document(index: str, doc_id: str): es().delete(index=index, id=doc_id) -def get_field_values(index: str, field: str, size: int) -> list[str]: - """ - Get the values for a given field (e.g. 
to populate list of filter values on keyword field) - Results are sorted descending by document frequency - see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html#search-aggregations-bucket-terms-aggregation-order - - :param index: The index - :param field: The field name - :return: A list of values - """ - aggs = {"unique_values": {"terms": {"field": field, "size": size}}} - r = es().search(index=index, size=0, aggs=aggs) - return [x["key"] for x in r["aggregations"]["unique_values"]["buckets"]] - - -def get_field_stats(index: str, field: str) -> list[str]: - """ - :param index: The index - :param field: The field name - :return: A list of values - """ - aggs = {"facets": {"stats": {"field": field}}} - r = es().search(index=index, size=0, aggs=aggs) - return r["aggregations"]["facets"] - - def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None = None): script_dict = dict(source=script, lang="painless", params=params or {}) es().update_by_query(index=index, script=script_dict, **query) diff --git a/amcat4/models.py b/amcat4/models.py index 798cd05..5bc5a2e 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -1,6 +1,34 @@ import pydantic from pydantic import BaseModel -from typing import Annotated, Literal +from typing import Annotated, Literal, NewType + + +AmcatType = Literal["text", "date", "boolean", "keyword", "number", "object", "vector", "geo"] +ElasticType = Literal[ + "text", + "annotated_text", + "binary", + "match_only_text", + "date", + "boolean", + "keyword", + "constant_keyword", + "wildcard", + "integer", + "byte", + "short", + "long", + "unsigned_long", + "float", + "half_float", + "double", + "scaled_float", + "object", + "flattened", + "nested", + "dense_vector", + "geo_point", +] class SnippetParams(BaseModel): @@ -10,9 +38,9 @@ class SnippetParams(BaseModel): the first [nomatch_chars] of the field. 
""" - nomatch_chars: Annotated[int, pydantic.Field(ge=1)] = 1 + nomatch_chars: Annotated[int, pydantic.Field(ge=1)] = 100 max_matches: Annotated[int, pydantic.Field(ge=0)] = 0 - match_chars: Annotated[int, pydantic.Field(ge=1)] = 1 + match_chars: Annotated[int, pydantic.Field(ge=1)] = 50 class FieldClientDisplay(BaseModel): @@ -32,15 +60,23 @@ class FieldMetareaderAccess(BaseModel): class Field(BaseModel): """Settings for a field.""" - type: str + type: AmcatType + elastic_type: ElasticType metareader: FieldMetareaderAccess = FieldMetareaderAccess() client_display: FieldClientDisplay = FieldClientDisplay() +class CreateField(BaseModel): + """Model for creating a field""" + + elastic_type: ElasticType + metareader: FieldMetareaderAccess | None = None + client_display: FieldClientDisplay | None = None + + class UpdateField(BaseModel): """Model for updating a field""" - type: str | None = None metareader: FieldMetareaderAccess | None = None client_display: FieldClientDisplay | None = None diff --git a/tests/test_elastic.py b/tests/test_elastic.py index e00fedc..910e7a0 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -4,12 +4,10 @@ refresh_index, upload_documents, get_document, - set_fields, - get_fields, update_document, update_tag_by_query, - get_field_values, ) +from amcat4.fields import set_fields, get_fields, get_field_values from amcat4.query import query_documents from tests.conftest import upload From 76c30a3c34e51a8c6b74a575dbb13c5f3f09d22b Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Tue, 30 Jan 2024 11:00:05 +0100 Subject: [PATCH 19/80] fixed add fields bug --- amcat4/fields.py | 2 +- amcat4/index.py | 10 +++++++--- amcat4/query.py | 1 + 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/amcat4/fields.py b/amcat4/fields.py index 1316eb8..139579e 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -146,7 +146,7 @@ def create_elastic_fields(index: str, fields: dict[str, ElasticType]): if elastic_type in ["date"]: mapping[field]["format"] = "strict_date_optional_time" - es().indices.create(index=index, mappings={"properties": mapping}) + es().indices.put_mapping(index=index, properties=mapping) def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: diff --git a/amcat4/index.py b/amcat4/index.py index 075ed02..119ca13 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -51,7 +51,7 @@ get_fields, set_fields, ) -from amcat4.models import ElasticType, Field +from amcat4.models import ElasticType class Role(IntEnum): @@ -140,12 +140,14 @@ def create_index( """ Create a new index in elasticsearch and register it with this AmCAT instance """ + delete_index(index, ignore_missing=True) default_mapping = {} for field, settings in DEFAULT_FIELDS.items(): - default_mapping[field] = get_elastic_field_mapping(settings.type) + default_mapping[field] = {"type": settings.elastic_type} es().indices.create(index=index, mappings={"properties": default_mapping}) + register_index( index, guest_role=guest_role, @@ -173,6 +175,7 @@ def register_index( if es().exists(index=system_index, id=index): raise ValueError(f"Index {index} is already registered") roles = [dict(email=admin, role="ADMIN")] if admin else [] + es().index( index=system_index, id=index, @@ -192,9 +195,10 @@ def delete_index(index: str, ignore_missing=False) -> None: :param index: The name of the index :param ignore_missing: If True, do not throw exception if index does not exist """ - deregister_index(index, ignore_missing=ignore_missing) _es = es().options(ignore_status=404) if 
ignore_missing else es() _es.indices.delete(index=index) + print(_es.indices.get_alias(index="*")) + deregister_index(index, ignore_missing=ignore_missing) def deregister_index(index: str, ignore_missing=False) -> None: diff --git a/amcat4/query.py b/amcat4/query.py index 0ed9ebd..50d7b3b 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -158,6 +158,7 @@ def query_documents( :param kwargs: Additional elements passed to Elasticsearch.search() :return: a QueryResult, or None if there is not scroll result anymore """ + print(index, fields, queries) if fields is not None and not isinstance(fields, list): raise ValueError("fields should be a list") From 74a54734006bc02d0b5a26b90f23bada474c17e5 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Thu, 1 Feb 2024 12:44:50 +0100 Subject: [PATCH 20/80] aggregation stuff --- amcat4/aggregate.py | 17 +++++++++++++++++ amcat4/fields.py | 2 +- amcat4/models.py | 1 + 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index ab768ef..bd9ab7d 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -45,6 +45,21 @@ def query(self): if self.ftype == "date": if m := interval_mapping(self.interval): return {self.name: {"terms": {"field": m.fieldname(self.field)}}} + # KW: auto_date_histogram is not supported within composite. + # Either we let client handle auto determining interval, or we drop composite + # (dropping composite matter relates to comment by WvA below) + # if self.interval == "auto": + # return { + # self.name: { + # "auto_date_histogram": { + # "field": self.field, + # "buckets": 30, + # "minimum_interval": "day", + # "format": "yyyy-MM-dd", + # } + # } + # } + return {self.name: {"date_histogram": {"field": self.field, "calendar_interval": self.interval}}} else: return {self.name: {"histogram": {"field": self.field, "interval": self.interval}}} @@ -146,6 +161,7 @@ def _elastic_aggregate( if aggregations: aggr["aggs"]["aggregations"] = aggregation_dsl(aggregations) kargs = {} + print(aggr) if filters or queries: q = build_body(queries=queries, filters=filters) kargs["query"] = q["query"] @@ -197,6 +213,7 @@ def _aggregate_results( # Run an aggregation with one or more axes sources = [axis.query() for axis in axes] runtime_mappings = _combine_mappings(axis.runtime_mappings() for axis in axes) + for bucket in _elastic_aggregate(index, sources, queries, filters, aggregations, runtime_mappings): row = tuple(axis.get_value(bucket["key"]) for axis in axes) row += (bucket["doc_count"],) diff --git a/amcat4/fields.py b/amcat4/fields.py index 139579e..65ba4a1 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -96,7 +96,7 @@ def get_default_field(elastic_type: ElasticType): type="date", elastic_type="date", metareader=FieldMetareaderAccess(access="read"), - client_display=FieldClientDisplay(in_list=True), + client_display=FieldClientDisplay(in_list=True, in_list_summary=True), ), "url": Field( type="keyword", diff --git a/amcat4/models.py b/amcat4/models.py index 5bc5a2e..91f5961 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -47,6 +47,7 @@ class FieldClientDisplay(BaseModel): """Client display settings for a specific field.""" in_list: bool = False + in_list_summary: bool = False in_document: bool = True From 75118e00d9ae387524d1219f914419a8d87954ee Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 2 Feb 2024 18:41:14 +0100 Subject: [PATCH 21/80] added pagination to aggregations --- amcat4/aggregate.py | 118 +++++++++++++++++++++++++++++++++----------- amcat4/api/query.py 
| 4 ++ 2 files changed, 92 insertions(+), 30 deletions(-) diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index bd9ab7d..1a2b92f 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -2,7 +2,8 @@ Aggregate queries """ from datetime import datetime -from typing import Mapping, Iterable, Union, Tuple, Sequence, List, Dict +from itertools import islice +from typing import Any, Mapping, Iterable, Union, Tuple, Sequence, List, Dict from amcat4.date_mappings import interval_mapping from amcat4.elastic import es @@ -114,11 +115,19 @@ def aggregation_dsl(aggregations: Iterable[Aggregation]) -> dict: class AggregateResult: - def __init__(self, axes: Sequence[Axis], aggregations: List[Aggregation], data: List[tuple], count_column: str = "n"): + def __init__( + self, + axes: Sequence[Axis], + aggregations: List[Aggregation], + data: List[tuple], + count_column: str = "n", + after: dict | None = None, + ): self.axes = axes self.data = data self.aggregations = aggregations self.count_column = count_column + self.after = after def as_dicts(self) -> Iterable[dict]: """Return the results as a sequence of {axis1, ..., n} dicts""" @@ -144,27 +153,29 @@ def _bare_aggregate(index: str | list[str], queries, filters, aggregations: Sequ def _elastic_aggregate( index: str | list[str], sources, + axes, queries, filters, aggregations: list[Aggregation], runtime_mappings: dict[str, Mapping] | None = None, after_key=None, -) -> Iterable[dict]: +) -> Tuple[list, dict | None]: """ Recursively get all buckets from a composite query. Yields 'buckets' consisting of {key: {axis: value}, doc_count: } """ # [WvA] Not sure if we should get all results ourselves or expose the 'after' pagination. # This might get us in trouble if someone e.g. aggregates on url or day for a large corpus - after = {"after": after_key} if after_key else {} + after = {"after": after_key} if after_key is not None and len(after_key) > 0 else {} aggr: Dict[str, Dict[str, dict]] = {"aggs": {"composite": dict(sources=sources, **after)}} if aggregations: aggr["aggs"]["aggregations"] = aggregation_dsl(aggregations) kargs = {} - print(aggr) + if filters or queries: q = build_body(queries=queries, filters=filters) kargs["query"] = q["query"] + result = es().search( index=index if isinstance(index, str) else ",".join(index), size=0, @@ -174,12 +185,19 @@ def _elastic_aggregate( ) if failure := result.get("_shards", {}).get("failures"): raise Exception(f"Error on running aggregate search: {failure}") - yield from result["aggregations"]["aggs"]["buckets"] + + buckets = result["aggregations"]["aggs"]["buckets"] after_key = result["aggregations"]["aggs"].get("after_key") - if after_key: - yield from _elastic_aggregate( - index, sources, queries, filters, aggregations, runtime_mappings=runtime_mappings, after_key=after_key - ) + + rows = [] + for bucket in buckets: + row = tuple(axis.get_value(bucket["key"]) for axis in axes) + row += (bucket["doc_count"],) + if aggregations: + row += tuple(a.get_value(bucket) for a in aggregations) + rows.append(row) + + return rows, after_key def _aggregate_results( @@ -188,38 +206,68 @@ def _aggregate_results( queries: dict[str, str] | None, filters: dict[str, FilterSpec] | None, aggregations: List[Aggregation], -) -> Iterable[tuple]: + after: dict[str, Any] | None = None, +): if not axes: + # Pagh 1 # No axes, so return aggregations (or total count) only if aggregations: count, results = _bare_aggregate(index, queries, filters, aggregations) - yield (count,) + tuple(a.get_value(results) for a in aggregations) + 
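+            # with no axes, the whole result is a single row: the total document count
+            # followed by one value per requested aggregation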
rows = [(count,) + tuple(a.get_value(results) for a in aggregations)] else: result = es().count( index=index if isinstance(index, str) else ",".join(index), **build_body(queries=queries, filters=filters) ) - yield result["count"], + rows = [(result["count"],)] + yield rows, None + elif any(ax.field == "_query" for ax in axes): + # Path 2 + # We cannot run the aggregation for multiple queries at once, so we loop over queries + # and recursively call _aggregate_results with one query at a time (which then uses path 3). if queries is None: raise ValueError("Queries must be specified when aggregating by query") # Strip off _query axis and run separate aggregation for each query i = [ax.field for ax in axes].index("_query") _axes = axes[:i] + axes[(i + 1) :] - for label, query in queries.items(): - for result_tuple in _aggregate_results(index, _axes, {label: query}, filters, aggregations): + query_items = list(queries.items()) + for label, query in query_items: + last_query = label == query_items[-1][0] + if after is not None and "_query" in after: + # after is a dict with the aggregation values from which to continue + # pagination. Since we loop over queries, we add the _query value. + # Then after continuing from the right query, we remove this _query + # key so that the after dict is as elastic expects it + after_query = after.pop("_query", None) + if after_query != label: + continue + + for rows, after in _aggregate_results(index, _axes, {label: query}, filters, aggregations, after=after): # insert label into the right position on the result tuple - yield result_tuple[:i] + (label,) + result_tuple[i:] + rows = [result_tuple[:i] + (label,) + result_tuple[i:] for result_tuple in rows] + + if after is None: + # if there are no buckets left for this query, we check if this is the last query. + # If not, we need to return the _query value to ensure pagination continues from this query + if not last_query: + after = {"_query": label} + else: + # if there are buckets left, we add the _query value to ensure pagination continues from this query + after["_query"] = label + yield rows, after + else: - # Run an aggregation with one or more axes + # Path 3 + # Run an aggregation with one or more axes. If after is not None, we continue from there. sources = [axis.query() for axis in axes] runtime_mappings = _combine_mappings(axis.runtime_mappings() for axis in axes) - for bucket in _elastic_aggregate(index, sources, queries, filters, aggregations, runtime_mappings): - row = tuple(axis.get_value(bucket["key"]) for axis in axes) - row += (bucket["doc_count"],) - if aggregations: - row += tuple(a.get_value(bucket) for a in aggregations) - yield row + rows, after = _elastic_aggregate(index, sources, axes, queries, filters, aggregations, runtime_mappings, after) + yield rows, after + + if after is not None: + for rows, after in _aggregate_results(index, axes, queries, filters, aggregations, after): + yield rows, after def query_aggregate( @@ -229,6 +277,7 @@ def query_aggregate( *, queries: dict[str, str] | None = None, filters: dict[str, FilterSpec] | None = None, + after: dict[str, Any] | None = None, ) -> AggregateResult: """ Conduct an aggregate query. 
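The `after` value added here is also the client-side contract: every AggregateResult now carries an `after` cursor, and feeding it back into the next call resumes the composite aggregation where the previous page stopped. A minimal consumption sketch (hypothetical index name "my_index"; signature as introduced in this patch):

    from amcat4.aggregate import Axis, query_aggregate

    # Page through a potentially large aggregation (e.g. one bucket per day)
    axes = [Axis(field="date", interval="day")]
    rows, after = [], None
    while True:
        result = query_aggregate("my_index", axes, [], after=after)
        rows += result.data      # tuples of (axis values..., count)
        after = result.after     # None once all buckets have been returned
        if after is None:
            break
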
@@ -266,10 +315,19 @@ def query_aggregate( aggregations = [] for aggregation in aggregations: aggregation.ftype = all_fields[aggregation.field].type - data = list(_aggregate_results(index, axes, queries, filters, aggregations)) - return AggregateResult( - axes, - aggregations, - data, - count_column="n", - ) + + # We get the rows in sets of queries * buckets, and if there are queries or buckets left, + # the last_after value serves as a pagination cursor. Once we have > [stop_after] rows, + # we return the data and the last_after cursor. If the user needs to collect the rest, + # they need to paginate + stop_after = 1000 + gen = _aggregate_results(index, axes, queries, filters, aggregations, after) + data = list() + last_after = None + for rows, after in gen: + data += rows + last_after = after + if len(data) > stop_after: + gen.close() + + return AggregateResult(axes, aggregations, data, count_column="n", after=last_after) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index b2e1913..d68f362 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -1,5 +1,6 @@ """API Endpoints for querying.""" +from re import search from typing import Annotated, Dict, List, Optional, Any, Union, Iterable, Literal from fastapi import APIRouter, HTTPException, status, Depends, Response, Body @@ -287,6 +288,7 @@ def query_aggregate_post( "which can be either a value, a list of values, or a FilterSpec dict", ), ] = None, + after: Annotated[dict[str, Any] | None, Body(description="After cursor for pagination")] = None, user: str = Depends(authenticated_user), ): """ @@ -313,10 +315,12 @@ def query_aggregate_post( results = aggregate.query_aggregate( indices, _axes, _aggregations, queries=_standardize_queries(queries), filters=_standardize_filters(filters) ) + return { "meta": { "axes": [axis.asdict() for axis in results.axes], "aggregations": [a.asdict() for a in results.aggregations], + "after": results.after, }, "data": list(results.as_dicts()), } From 023b437e1ecfd312a493ffd7756fda217cba3353 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Sat, 10 Feb 2024 11:48:37 +0100 Subject: [PATCH 22/80] fixed some bugs in config_amcat --- .env.example | 2 +- amcat4/__main__.py | 26 +++++++++++++------------- amcat4/config.py | 3 ++- requirements.txt | 42 ++++++++++++++++++++++++++++++++++++++++++ setup.py | 1 + 5 files changed, 59 insertions(+), 15 deletions(-) create mode 100644 requirements.txt diff --git a/.env.example b/.env.example index 2badb13..c3e5e9c 100644 --- a/.env.example +++ b/.env.example @@ -19,7 +19,7 @@ amcat4_elastic_verify_ssl=False amcat4_auth=no_auth # Middlecat server to trust as ID provider -amcat4_middlecat_url=https://middlecat.up.railway.app +amcat4_middlecat_url=https://middlecat.net # Email address for a hardcoded admin email (useful for setup and recovery) #amcat4_admin_email= diff --git a/amcat4/__main__.py b/amcat4/__main__.py index b9eb183..7c21c39 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -1,6 +1,7 @@ """ AmCAT4 REST API """ + import argparse import csv import io @@ -130,7 +131,7 @@ def migrate_index(_args) -> None: def base_env(): return dict( amcat4_secret_key=secrets.token_hex(nbytes=32), - amcat4_middlecat_url="https://middlecat.up.netlify.app", + amcat4_middlecat_url="https://middlecat.net", ) @@ -181,28 +182,27 @@ def list_users(_args): def config_amcat(args): settings = get_settings() - settings_dict = settings.model_dump() + settings_fields = settings.model_fields # Not a useful entry in an actual env_file - env_file_location = 
settings_dict.pop("env_file") - print(f"Reading/writing settings from {env_file_location}") - for fieldname in settings.model_fields_set: - if fieldname not in settings_dict: + print(f"Reading/writing settings from {settings.env_file}") + for fieldname, fieldinfo in settings.model_fields.items(): + if fieldname == "env_file": continue - fieldinfo = settings_dict[fieldname] + validation_function = AuthOptions.validate if fieldname == "auth" else None value = getattr(settings, fieldname) value = menu(fieldname, fieldinfo, value, validation_function=validation_function) if value is ABORTED: return if value is not UNCHANGED: - settings_dict[fieldname] = value + setattr(settings, fieldname, value) - with env_file_location.open("w") as f: - for fieldname, value in settings_dict.items(): - fieldinfo = settings_dict[fieldname] + with settings.env_file.open("w") as f: + for fieldname, fieldinfo in settings.model_fields.items(): + value = getattr(settings, fieldname) if doc := fieldinfo.description: f.write(f"# {doc}\n") - if _isenum(fieldinfo): + if _isenum(fieldinfo) and fieldinfo.annotation: f.write("# Valid options:\n") for option in fieldinfo.annotation: doc = option.__doc__.replace("\n", " ") @@ -212,7 +212,7 @@ def config_amcat(args): else: f.write(f"amcat4_{fieldname}={value}\n\n") os.chmod(".env", 0o600) - print(f"*** Written {bold('.env')} file to {env_file_location} ***") + print(f"*** Written {bold('.env')} file to {settings.env_file} ***") def bold(x): diff --git a/amcat4/config.py b/amcat4/config.py index 6d339f2..b560f74 100644 --- a/amcat4/config.py +++ b/amcat4/config.py @@ -6,6 +6,7 @@ - A .env file, either in the current working directory or in a location specified by the AMCAT4_CONFIG_FILE environment variable """ + import functools from enum import Enum from pathlib import Path @@ -100,7 +101,7 @@ class Settings(BaseSettings): Field( description="Middlecat server to trust as ID provider", ), - ] = "https://middlecat.up.railway.app" + ] = "https://middlecat.net" admin_email: Annotated[ str | None, diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..75ad82a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,42 @@ +anyio==3.7.1 +asgiref==3.6.0 +attrs==19.3.0 +Authlib==1.2.1 +certifi==2023.7.22 +cffi==1.15.1 +charset-normalizer==3.2.0 +class-doc==0.2.0b0 +click==8.1.6 +cryptography==41.0.2 +dnspython==2.4.1 +elastic-transport==8.4.0 +elasticsearch==8.8.2 +email-validator==1.3.1 +exceptiongroup==1.1.2 +fastapi==0.78.0 +h11==0.14.0 +httptools==0.6.0 +idna==3.4 +itsdangerous==2.1.2 +Jinja2==3.1.2 +MarkupSafe==2.1.3 +more-itertools==7.2.0 +orjson==3.9.2 +pycparser==2.21 +pydantic==1.9.2 +pydantic-settings==0.2.5 +python-dotenv==1.0.0 +python-multipart==0.0.5 +PyYAML==5.4.1 +requests==2.31.0 +six==1.16.0 +sniffio==1.3.0 +starlette==0.19.1 +tomlkit==0.5.11 +typing-extensions==3.10.0.2 +ujson==5.8.0 +urllib3==1.26.16 +uvicorn==0.17.6 +uvloop==0.17.0 +watchgod==0.8.2 +websockets==11.0.3 diff --git a/setup.py b/setup.py index c9b1d0b..3e4360a 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "uvicorn", "requests", "class_doc", + "mypy", ], extras_require={"dev": ["pytest", "mypy", "flake8", "responses", "pre-commit", "types-requests"]}, entry_points={"console_scripts": ["amcat4 = amcat4.__main__:main"]}, From 601ce36b4117666650a6aea085b8a9dd893db530 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 14 Feb 2024 18:13:29 +0100 Subject: [PATCH 23/80] standardizing how role is stored --- amcat4/api/index.py | 19 +++++++++++-------- amcat4/index.py | 25 
+++++++++++++------------ 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 2a73721..646d01a 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -1,4 +1,5 @@ """API Endpoints for document and index management.""" + from http import HTTPStatus from typing import Annotated, Any, Literal @@ -29,8 +30,9 @@ def index_list(current_user: str = Depends(authenticated_user)): def index_to_dict(ix: index.Index) -> dict: ix_dict = ix._asdict() - ix_dict["guest_role"] = ix_dict["guest_role"] and ix_dict["guest_role"].name - del ix_dict["roles"] + guest_role_int = ix_dict.get("guest_role", 0) + + ix_dict = dict(id=ix_dict["id"], name=ix_dict["name"], guest_role=index.Role(guest_role_int).name) return ix_dict return [index_to_dict(ix) for ix in index.list_known_indices(current_user)] @@ -40,8 +42,8 @@ class NewIndex(BaseModel): """Form to create a new index.""" id: str - guest_role: RoleType | None = None name: str | None = None + guest_role: RoleType | None = None description: str | None = None @@ -75,7 +77,6 @@ class ChangeIndex(BaseModel): guest_role: Literal["ADMIN", "WRITER", "READER", "METAREADER", "NONE"] | None = "NONE" name: str | None = None description: str | None = None - summary_field: str | None = None @app_index.put("/{ix}") @@ -88,20 +89,20 @@ def modify_index(ix: str, data: ChangeIndex, user: str = Depends(authenticated_u User needs admin rights on the index """ check_role(user, index.Role.ADMIN, ix) - guest_role, remove_guest_role = None, False + guest_role, remove_guest_role = index.Role.NONE, False if data.guest_role: - role = data.guest_role.upper() + role = data.guest_role if role == "NONE": remove_guest_role = True else: guest_role = index.Role[role] + index.modify_index( ix, name=data.name, description=data.description, guest_role=guest_role, remove_guest_role=remove_guest_role, - summary_field=data.summary_field, ) refresh_system_index() @@ -115,7 +116,9 @@ def view_index(ix: str, user: str = Depends(authenticated_user)): role = check_role(user, index.Role.METAREADER, ix, required_global_role=index.Role.WRITER) d = index.get_index(ix)._asdict() d["user_role"] = role.name - d["guest_role"] = d.get("guest_role", index.Role.NONE.name) + d["guest_role"] = index.Role(d.get("guest_role", 0)).name + d["description"] = d.get("description", "") or "" + d["name"] = d.get("name", "") or "" return d except index.IndexDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") diff --git a/amcat4/index.py b/amcat4/index.py index 119ca13..9ed0b16 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -31,6 +31,7 @@ - We define the mappings (field types) based on existing elasticsearch mappings, but use field metadata to define specific fields. 
""" + import collections from enum import IntEnum from typing import Any, Iterable, Optional, Literal @@ -150,9 +151,9 @@ def create_index( register_index( index, - guest_role=guest_role, - name=name, - description=description, + guest_role=guest_role or Role.NONE, + name=name or index, + description=description or "", admin=admin, ) @@ -176,6 +177,11 @@ def register_index( raise ValueError(f"Index {index} is already registered") roles = [dict(email=admin, role="ADMIN")] if admin else [] + if guest_role is not None: + guest_role_int = guest_role.value + else: + guest_role_int = Role.NONE.value + es().index( index=system_index, id=index, @@ -183,7 +189,7 @@ def register_index( name=(name or index), roles=roles, description=description, - guest_role=guest_role and guest_role.name, + guest_role=guest_role_int, ), ) refresh_index(system_index) @@ -272,22 +278,17 @@ def modify_index( remove_guest_role=False, summary_field=None, ): + doc = dict( name=name, description=description, - guest_role=guest_role and guest_role.name, + guest_role=guest_role and guest_role.value, summary_field=summary_field, ) - if summary_field is not None: - f = get_fields(index) - if summary_field not in f: - raise ValueError(f"Summary field {summary_field} does not exist!") - if f[summary_field].type not in ["date", "keyword", "tag"]: - raise ValueError(f"Summary field {summary_field} should be date, keyword or tag, not {f[summary_field].type}!") doc = {x: v for (x, v) in doc.items() if v} if remove_guest_role: - doc["guest_role"] = Role.NONE + doc["guest_role"] = Role.NONE.value if doc: es().update(index=get_settings().system_index, id=index, doc=doc) From ac2c0e721b8e141054ba316d9a0500035fc7825b Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Thu, 15 Feb 2024 10:27:10 +0100 Subject: [PATCH 24/80] some error handling --- amcat4/api/users.py | 5 +++-- amcat4/fields.py | 3 +-- amcat4/index.py | 25 ++++++++++++++++++++++--- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/amcat4/api/users.py b/amcat4/api/users.py index ff49540..511a0ec 100644 --- a/amcat4/api/users.py +++ b/amcat4/api/users.py @@ -4,6 +4,7 @@ AmCAT4 can use either Basic or Token-based authentication. A client can request a token with basic authentication and store that token for future requests. 
""" + from typing import Literal, Optional from importlib.metadata import version @@ -14,7 +15,7 @@ from amcat4 import index from amcat4.api.auth import authenticated_user, authenticated_admin, check_global_role from amcat4.config import get_settings, validate_settings -from amcat4.index import Role, set_global_role, get_global_role +from amcat4.index import Role, set_global_role, get_global_role, user_exists app_users = APIRouter(tags=["users"]) @@ -38,7 +39,7 @@ class ChangeUserForm(BaseModel): @app_users.post("/users", status_code=status.HTTP_201_CREATED) def create_user(new_user: UserForm, _=Depends(authenticated_admin)): """Create a new user.""" - if get_global_role(new_user.email, only_es=True) is not None: + if user_exists(new_user.email): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {new_user.email} already exists", diff --git a/amcat4/fields.py b/amcat4/fields.py index 65ba4a1..4748b81 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -12,7 +12,6 @@ system index """ - from typing import Any, Iterator @@ -100,7 +99,7 @@ def get_default_field(elastic_type: ElasticType): ), "url": Field( type="keyword", - elastic_type="keyword", + elastic_type="wildcard", metareader=FieldMetareaderAccess(access="read"), client_display=FieldClientDisplay(in_list=True), ), diff --git a/amcat4/index.py b/amcat4/index.py index 9ed0b16..9375aeb 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -141,12 +141,16 @@ def create_index( """ Create a new index in elasticsearch and register it with this AmCAT instance """ - delete_index(index, ignore_missing=True) - default_mapping = {} for field, settings in DEFAULT_FIELDS.items(): default_mapping[field] = {"type": settings.elastic_type} + try: + get_index(index) + raise ValueError(f'Index "{index}" already exists') + except IndexDoesNotExist: + pass + es().indices.create(index=index, mappings={"properties": default_mapping}) register_index( @@ -203,7 +207,6 @@ def delete_index(index: str, ignore_missing=False) -> None: """ _es = es().options(ignore_status=404) if ignore_missing else es() _es.indices.delete(index=index) - print(_es.indices.get_alias(index="*")) deregister_index(index, ignore_missing=ignore_missing) @@ -307,6 +310,22 @@ def remove_global_role(email: str): remove_role(index=GLOBAL_ROLES, email=email) +def user_exists(email: str, index: str = GLOBAL_ROLES) -> bool: + """ + Check if a user exists on server (GLOBAL_ROLES) or in a specific index + """ + try: + doc = es().get( + index=get_settings().system_index, + id=index, + source_includes=["roles", "guest_role"], + ) + except NotFoundError: + raise IndexDoesNotExist(f"Index {index} does not exist or is not registered") + roles_dict = _roles_from_elastic(doc["_source"].get("roles", [])) + return email in roles_dict + + def get_role(index: str, email: str) -> Role: """ Retrieve the role of this user on this index, or the guest role if user has no role From 56bc862236c3cb8f67d07c3361fa838d5e0c955d Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Thu, 15 Feb 2024 17:56:20 +0100 Subject: [PATCH 25/80] fixed bug in aggregate buckets --- amcat4/aggregate.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index 1a2b92f..afad219 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -1,6 +1,7 @@ """ Aggregate queries """ + from datetime import datetime from itertools import islice from typing import Any, Mapping, Iterable, Union, Tuple, Sequence, List, Dict @@ -208,8 +209,9 @@ def 
_aggregate_results( aggregations: List[Aggregation], after: dict[str, Any] | None = None, ): - if not axes: - # Pagh 1 + + if not axes or len(axes) == 0: + # Path 1 # No axes, so return aggregations (or total count) only if aggregations: count, results = _bare_aggregate(index, queries, filters, aggregations) @@ -233,6 +235,7 @@ def _aggregate_results( query_items = list(queries.items()) for label, query in query_items: last_query = label == query_items[-1][0] + if after is not None and "_query" in after: # after is a dict with the aggregation values from which to continue # pagination. Since we loop over queries, we add the _query value. @@ -242,19 +245,19 @@ def _aggregate_results( if after_query != label: continue - for rows, after in _aggregate_results(index, _axes, {label: query}, filters, aggregations, after=after): + for rows, after_buckets in _aggregate_results(index, _axes, {label: query}, filters, aggregations, after=after): # insert label into the right position on the result tuple rows = [result_tuple[:i] + (label,) + result_tuple[i:] for result_tuple in rows] - if after is None: + if after_buckets is None: # if there are no buckets left for this query, we check if this is the last query. # If not, we need to return the _query value to ensure pagination continues from this query if not last_query: - after = {"_query": label} + after_buckets = {"_query": label} else: # if there are buckets left, we add the _query value to ensure pagination continues from this query - after["_query"] = label - yield rows, after + after_buckets["_query"] = label + yield rows, after_buckets else: # Path 3 From cd25315cfdd33d26ede7487022fa31003eb2223f Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 1 Mar 2024 11:31:39 +0100 Subject: [PATCH 26/80] make all client settings free form --- amcat4/fields.py | 30 +++++------------------------- amcat4/models.py | 19 ++++++------------- amcat4/query.py | 3 ++- 3 files changed, 13 insertions(+), 39 deletions(-) diff --git a/amcat4/fields.py b/amcat4/fields.py index 4748b81..5a3ba52 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -20,7 +20,7 @@ # from amcat4.api.common import py2dict from amcat4.config import get_settings from amcat4.elastic import es -from amcat4.models import AmcatType, ElasticType, Field, FieldClientDisplay, UpdateField, updateField, FieldMetareaderAccess +from amcat4.models import AmcatType, ElasticType, Field, UpdateField, updateField, FieldMetareaderAccess # given an elastic field type, infer @@ -79,30 +79,10 @@ def get_default_field(elastic_type: ElasticType): # default fields when a new index is created DEFAULT_FIELDS = { - "text": Field( - type="text", - elastic_type="text", - metareader=FieldMetareaderAccess(access="none"), - client_display=FieldClientDisplay(in_list=True), - ), - "title": Field( - type="text", - elastic_type="text", - metareader=FieldMetareaderAccess(access="read"), - client_display=FieldClientDisplay(in_list=True), - ), - "date": Field( - type="date", - elastic_type="date", - metareader=FieldMetareaderAccess(access="read"), - client_display=FieldClientDisplay(in_list=True, in_list_summary=True), - ), - "url": Field( - type="keyword", - elastic_type="wildcard", - metareader=FieldMetareaderAccess(access="read"), - client_display=FieldClientDisplay(in_list=True), - ), + "text": Field(type="text", elastic_type="text", metareader=FieldMetareaderAccess(access="none"), client_settings={}), + "title": Field(type="text", elastic_type="text", metareader=FieldMetareaderAccess(access="read"), 
client_settings={}), + "date": Field(type="date", elastic_type="date", metareader=FieldMetareaderAccess(access="read"), client_settings={}), + "url": Field(type="keyword", elastic_type="wildcard", metareader=FieldMetareaderAccess(access="read"), client_settings={}), } diff --git a/amcat4/models.py b/amcat4/models.py index 91f5961..7d6096b 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -1,6 +1,6 @@ import pydantic from pydantic import BaseModel -from typing import Annotated, Literal, NewType +from typing import Annotated, Any, Literal, NewType AmcatType = Literal["text", "date", "boolean", "keyword", "number", "object", "vector", "geo"] @@ -43,14 +43,6 @@ class SnippetParams(BaseModel): match_chars: Annotated[int, pydantic.Field(ge=1)] = 50 -class FieldClientDisplay(BaseModel): - """Client display settings for a specific field.""" - - in_list: bool = False - in_list_summary: bool = False - in_document: bool = True - - class FieldMetareaderAccess(BaseModel): """Metareader access for a specific field.""" @@ -59,12 +51,13 @@ class FieldMetareaderAccess(BaseModel): class Field(BaseModel): - """Settings for a field.""" + """Settings for a field. Some settings, such as metareader, have a strict model because they are used + server side. Others, such as client_settings, are free-form and can be used by the client to store settings.""" type: AmcatType elastic_type: ElasticType metareader: FieldMetareaderAccess = FieldMetareaderAccess() - client_display: FieldClientDisplay = FieldClientDisplay() + client_settings: dict[str, Any] = {} class CreateField(BaseModel): @@ -72,14 +65,14 @@ class CreateField(BaseModel): elastic_type: ElasticType metareader: FieldMetareaderAccess | None = None - client_display: FieldClientDisplay | None = None + client_settings: dict[str, Any] | None = None class UpdateField(BaseModel): """Model for updating a field""" metareader: FieldMetareaderAccess | None = None - client_display: FieldClientDisplay | None = None + client_settings: dict[str, Any] | None = None def updateField(field: Field, update: UpdateField | Field): diff --git a/amcat4/query.py b/amcat4/query.py index 50d7b3b..4b5b0cd 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -1,6 +1,7 @@ """ All things query """ + from math import ceil from typing import ( @@ -158,7 +159,7 @@ def query_documents( :param kwargs: Additional elements passed to Elasticsearch.search() :return: a QueryResult, or None if there is not scroll result anymore """ - print(index, fields, queries) + print(page) if fields is not None and not isinstance(fields, list): raise ValueError("fields should be a list") From ff21ae2b1ca32c7360e6cfbd0843378454ea3b02 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 1 Mar 2024 12:40:51 +0100 Subject: [PATCH 27/80] stupid aggregate stuff --- amcat4/aggregate.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index afad219..9989fe5 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -2,8 +2,10 @@ Aggregate queries """ +import copy from datetime import datetime from itertools import islice +import json from typing import Any, Mapping, Iterable, Union, Tuple, Sequence, List, Dict from amcat4.date_mappings import interval_mapping @@ -38,7 +40,7 @@ def __init__(self, field: str, interval: str | None = None, name: str | None = N self.name = field def __repr__(self): - return f"" + return f"" def query(self): if not self.ftype: @@ -224,6 +226,7 @@ def _aggregate_results( yield rows, None elif any(ax.field == 
"_query" for ax in axes): + # Path 2 # We cannot run the aggregation for multiple queries at once, so we loop over queries # and recursively call _aggregate_results with one query at a time (which then uses path 3). @@ -232,6 +235,7 @@ def _aggregate_results( # Strip off _query axis and run separate aggregation for each query i = [ax.field for ax in axes].index("_query") _axes = axes[:i] + axes[(i + 1) :] + query_items = list(queries.items()) for label, query in query_items: last_query = label == query_items[-1][0] @@ -246,6 +250,8 @@ def _aggregate_results( continue for rows, after_buckets in _aggregate_results(index, _axes, {label: query}, filters, aggregations, after=after): + after_buckets = copy.deepcopy(after_buckets) + # insert label into the right position on the result tuple rows = [result_tuple[:i] + (label,) + result_tuple[i:] for result_tuple in rows] @@ -323,7 +329,7 @@ def query_aggregate( # the last_after value serves as a pagination cursor. Once we have > [stop_after] rows, # we return the data and the last_after cursor. If the user needs to collect the rest, # they need to paginate - stop_after = 1000 + stop_after = 500 gen = _aggregate_results(index, axes, queries, filters, aggregations, after) data = list() last_after = None From 5649592d0a67e16220104c06c02a4517743b139b Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 1 Mar 2024 18:35:54 +0100 Subject: [PATCH 28/80] aggregation stuff --- amcat4/aggregate.py | 19 ++----------------- amcat4/api/query.py | 7 ++++++- amcat4/query.py | 1 - 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index 9989fe5..c3d469d 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -49,26 +49,11 @@ def query(self): if self.ftype == "date": if m := interval_mapping(self.interval): return {self.name: {"terms": {"field": m.fieldname(self.field)}}} - # KW: auto_date_histogram is not supported within composite. - # Either we let client handle auto determining interval, or we drop composite - # (dropping composite matter relates to comment by WvA below) - # if self.interval == "auto": - # return { - # self.name: { - # "auto_date_histogram": { - # "field": self.field, - # "buckets": 30, - # "minimum_interval": "day", - # "format": "yyyy-MM-dd", - # } - # } - # } - return {self.name: {"date_histogram": {"field": self.field, "calendar_interval": self.interval}}} else: return {self.name: {"histogram": {"field": self.field, "interval": self.interval}}} else: - return {self.name: {"terms": {"field": self.field}}} + return {self.name: {"terms": {"field": self.field, "order": "desc"}}} def get_value(self, values): value = values[self.name] @@ -329,7 +314,7 @@ def query_aggregate( # the last_after value serves as a pagination cursor. Once we have > [stop_after] rows, # we return the data and the last_after cursor. 
If the user needs to collect the rest, # they need to paginate - stop_after = 500 + stop_after = 249 gen = _aggregate_results(index, axes, queries, filters, aggregations, after) data = list() last_after = None diff --git a/amcat4/api/query.py b/amcat4/api/query.py index d68f362..52ca336 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -313,7 +313,12 @@ def query_aggregate_post( ) results = aggregate.query_aggregate( - indices, _axes, _aggregations, queries=_standardize_queries(queries), filters=_standardize_filters(filters) + indices, + _axes, + _aggregations, + queries=_standardize_queries(queries), + filters=_standardize_filters(filters), + after=after, ) return { diff --git a/amcat4/query.py b/amcat4/query.py index 4b5b0cd..beeb4c0 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -159,7 +159,6 @@ def query_documents( :param kwargs: Additional elements passed to Elasticsearch.search() :return: a QueryResult, or None if there is not scroll result anymore """ - print(page) if fields is not None and not isinstance(fields, list): raise ValueError("fields should be a list") From 86fccfd2f8e0939287e2f1d78416bc0a2022f603 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 6 Mar 2024 13:40:18 +0100 Subject: [PATCH 29/80] fixed bug in aggregation pagination --- .vscode/settings.json | 1 + amcat4/aggregate.py | 10 ++++++---- amcat4/elastic.py | 1 + 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 2e83fb4..412dd90 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,6 +8,7 @@ "mypy.enabled": true, "mypy.runUsingActiveInterpreter": true, "python.analysis.typeCheckingMode": "basic", + "python.analysis.autoImportCompletions": true, "flake8.args": ["--max-line-length=127", "--ignore=E203"] } diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index c3d469d..fdcfaba 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -230,9 +230,9 @@ def _aggregate_results( # pagination. Since we loop over queries, we add the _query value. # Then after continuing from the right query, we remove this _query # key so that the after dict is as elastic expects it - after_query = after.pop("_query", None) - if after_query != label: + if after.get("_query") != label: continue + after.pop("_query", None) for rows, after_buckets in _aggregate_results(index, _axes, {label: query}, filters, aggregations, after=after): after_buckets = copy.deepcopy(after_buckets) @@ -250,6 +250,9 @@ def _aggregate_results( after_buckets["_query"] = label yield rows, after_buckets + # after only applies to the first query + after = None + else: # Path 3 # Run an aggregation with one or more axes. If after is not None, we continue from there. @@ -314,7 +317,7 @@ def query_aggregate( # the last_after value serves as a pagination cursor. Once we have > [stop_after] rows, # we return the data and the last_after cursor. 
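     #
     # (illustrative sketch, not part of this patch: one way a caller could
     #  drain this cursor; assumes query_aggregate/Axis as defined in this
     #  file and an AggregateResult exposing .data and .after)
     #
     #   res = query_aggregate(index, [Axis("cat")], [], queries=None, filters=None)
     #   rows = list(res.data)
     #   while res.after is not None:
     #       res = query_aggregate(index, [Axis("cat")], [], after=res.after)
     #       rows += list(res.data)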
     # If the user needs to collect the rest, they need to paginate
-    stop_after = 249
+    stop_after = 1000
     gen = _aggregate_results(index, axes, queries, filters, aggregations, after)
     data = list()
     last_after = None
@@ -323,5 +326,4 @@
         last_after = after
         if len(data) > stop_after:
             gen.close()
-
     return AggregateResult(axes, aggregations, data, count_column="n", after=last_after)
diff --git a/amcat4/elastic.py b/amcat4/elastic.py
index ef0a508..563d047 100644
--- a/amcat4/elastic.py
+++ b/amcat4/elastic.py
@@ -8,6 +8,7 @@
 {auth: [{email: role}], guest_role: role}
 """
+
 import functools
 import logging
 

From 5557200ae779112d270c1860dad20b9f6ebc9c64 Mon Sep 17 00:00:00 2001
From: Kasper Welbers
Date: Sun, 10 Mar 2024 14:12:47 +0100
Subject: [PATCH 30/80] added middleware for auto gzip (if request header
 specifies accepting gzip)

---
 amcat4/api/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/amcat4/api/__init__.py b/amcat4/api/__init__.py
index 35e8e43..4ccdea8 100644
--- a/amcat4/api/__init__.py
+++ b/amcat4/api/__init__.py
@@ -3,6 +3,7 @@
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.middleware.gzip import GZipMiddleware
 
 from amcat4.api.index import app_index
 from amcat4.api.info import app_info
@@ -37,6 +38,7 @@
     allow_methods=["*"],
     allow_headers=["*"],
 )
+app.add_middleware(GZipMiddleware, minimum_size=1000)
 
 
 @app.exception_handler(ValueError)

From 1c6651d770da4382ca9a1e239c9178f516644fd2 Mon Sep 17 00:00:00 2001
From: Kasper Welbers
Date: Sun, 17 Mar 2024 18:01:24 +0100
Subject: [PATCH 31/80] custom identifiers and no more defaults

---
 amcat4/api/__init__.py  |  1 +
 amcat4/api/index.py     |  4 ++--
 amcat4/fields.py        | 36 ++++++++++++++--------------------
 amcat4/index.py         | 43 +++++++++++++++++++++++------------------
 amcat4/models.py        |  4 +++-
 tests/test_api_index.py |  4 ++--
 tests/test_elastic.py   |  4 ++--
 tests/test_index.py     |  4 ++--
 8 files changed, 51 insertions(+), 49 deletions(-)

diff --git a/amcat4/api/__init__.py b/amcat4/api/__init__.py
index 4ccdea8..2cba783 100644
--- a/amcat4/api/__init__.py
+++ b/amcat4/api/__init__.py
@@ -10,6 +10,7 @@
 from amcat4.api.query import app_query
 from amcat4.api.users import app_users
 
+
 app = FastAPI(
     title="AmCAT4",
     description=__doc__,
diff --git a/amcat4/api/index.py b/amcat4/api/index.py
index 646d01a..ecbbdbe 100644
--- a/amcat4/api/index.py
+++ b/amcat4/api/index.py
@@ -239,7 +239,7 @@ def create_fields(
         types[field] = value
 
     if len(update_fields) > 0:
-        index.set_fields(ix, update_fields)
+        index.update_fields(ix, update_fields)
 
     return "", HTTPStatus.NO_CONTENT
 
@@ -263,7 +263,7 @@ def update_fields(
     """
     check_role(user, index.Role.WRITER, ix)
 
-    index.set_fields(ix, fields)
+    index.update_fields(ix, fields)
 
     return "", HTTPStatus.NO_CONTENT
 
diff --git a/amcat4/fields.py b/amcat4/fields.py
index 5a3ba52..5ee10f4 100644
--- a/amcat4/fields.py
+++ b/amcat4/fields.py
@@ -20,7 +20,7 @@
 # from amcat4.api.common import py2dict
 from amcat4.config import get_settings
 from amcat4.elastic import es
-from amcat4.models import AmcatType, ElasticType, Field, UpdateField, updateField, FieldMetareaderAccess
+from amcat4.models import AmcatType, CreateField, ElasticType, Field, UpdateField, updateField, FieldMetareaderAccess
 
 
 # given an elastic field type, infer
@@ -77,15 +77,6 @@ def get_default_field(elastic_type: ElasticType):
     return Field(type=amcat_type, elastic_type=elastic_type, metareader=get_default_metareader(amcat_type))
 
 
-# default
fields when a new index is created -DEFAULT_FIELDS = { - "text": Field(type="text", elastic_type="text", metareader=FieldMetareaderAccess(access="none"), client_settings={}), - "title": Field(type="text", elastic_type="text", metareader=FieldMetareaderAccess(access="read"), client_settings={}), - "date": Field(type="date", elastic_type="date", metareader=FieldMetareaderAccess(access="read"), client_settings={}), - "url": Field(type="keyword", elastic_type="wildcard", metareader=FieldMetareaderAccess(access="read"), client_settings={}), -} - - def coerce_type(value: Any, elastic_type: ElasticType): """ Coerces values into the respective type in elastic @@ -104,28 +95,29 @@ def coerce_type(value: Any, elastic_type: ElasticType): return value -def create_elastic_fields(index: str, fields: dict[str, ElasticType]): +def create_fields(index: str, fields: dict[str, CreateField]): mapping: dict[str, Any] = {} current_fields = {k: v for k, v in _get_index_fields(index)} - for field, elastic_type in fields.items(): - if TYPEMAP_ES_TO_AMCAT.get(elastic_type) is None: - raise ValueError(f"Field type {elastic_type} not supported by AmCAT") + for field, settings in fields.items(): + if TYPEMAP_ES_TO_AMCAT.get(settings.elastic_type) is None: + raise ValueError(f"Field type {settings.elastic_type} not supported by AmCAT") current_type = current_fields.get(field) if current_type is not None: - if current_type != elastic_type: + if current_type != settings.elastic_type: raise ValueError( - f"Field '{field}' already exists with type '{current_type}'. Cannot change type to '{elastic_type}'" + f"Field '{field}' already exists with type '{current_type}'. Cannot change type to '{settings.elastic_type}'" ) continue - mapping[field] = {"type": elastic_type} + mapping[field] = {"type": settings.elastic_type} - if elastic_type in ["date"]: + if settings.elastic_type in ["date"]: mapping[field]["format"] = "strict_date_optional_time" es().indices.put_mapping(index=index, properties=mapping) + update_fields(index, fields) def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: @@ -138,7 +130,7 @@ def _fields_from_elastic( return {fs["field"]: Field.model_validate(fs["settings"]) for fs in fields} -def set_fields(index: str, new_fields: dict[str, UpdateField] | dict[str, Field]): +def update_fields(index: str, new_fields: dict[str, UpdateField] | dict[str, Field] | dict[str, CreateField]): """ Set the fields settings for this index. Only updates fields that already exist. type and elastic_type cannot be changed. 
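# (illustrative sketch, not part of the patch: how create_fields/update_fields
#  above are meant to be called; "my_index" and "party" are hypothetical names,
#  and CreateField, UpdateField and FieldMetareaderAccess come from amcat4.models)
#
#   create_fields("my_index", {"party": CreateField(elastic_type="keyword")})
#   update_fields("my_index", {"party": UpdateField(
#       metareader=FieldMetareaderAccess(access="read"))})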
@@ -167,8 +159,10 @@ def set_fields(index: str, new_fields: dict[str, UpdateField] | dict[str, Field] def _get_index_fields(index: str) -> Iterator[tuple[str, ElasticType]]: r = es().indices.get_mapping(index=index) - for k, v in r[index]["mappings"]["properties"].items(): - yield k, v.get("type", "object") + + if len(r[index]["mappings"]) > 0: + for k, v in r[index]["mappings"]["properties"].items(): + yield k, v.get("type", "object") def get_fields(index: str) -> dict[str, Field]: diff --git a/amcat4/index.py b/amcat4/index.py index 9375aeb..998e178 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -46,13 +46,12 @@ from amcat4.config import get_settings from amcat4.elastic import es from amcat4.fields import ( - DEFAULT_FIELDS, coerce_type, - create_elastic_fields, + create_fields, get_fields, - set_fields, + update_fields, ) -from amcat4.models import ElasticType +from amcat4.models import CreateField, ElasticType, Field class Role(IntEnum): @@ -141,17 +140,13 @@ def create_index( """ Create a new index in elasticsearch and register it with this AmCAT instance """ - default_mapping = {} - for field, settings in DEFAULT_FIELDS.items(): - default_mapping[field] = {"type": settings.elastic_type} - try: get_index(index) raise ValueError(f'Index "{index}" already exists') except IndexDoesNotExist: pass - es().indices.create(index=index, mappings={"properties": default_mapping}) + es().indices.create(index=index, mappings={"properties": {}}) register_index( index, @@ -161,7 +156,7 @@ def create_index( admin=admin, ) - set_fields(index, DEFAULT_FIELDS) + # update_fields(index, DEFAULT_FIELDS) def register_index( @@ -406,17 +401,26 @@ def delete_user(email: str) -> None: set_role(ix.id, email, None) -def _get_hash(document: dict) -> str: +def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: """ Get the hash for a document """ - hash_str = json.dumps(document, sort_keys=True, ensure_ascii=True, default=str).encode("ascii") + + identifiers = [k for k, v in field_settings.items() if v.identifier] + if len(identifiers) == 0: + # if no identifiers specified, id is hash of entire document + hash_values = document + else: + # if identifiers specified, id is hash of those fields + hash_values = {k: document.get(k) for k in identifiers if k in document} + + hash_str = json.dumps(hash_values, sort_keys=True, ensure_ascii=True, default=str).encode("ascii") m = hashlib.sha224() m.update(hash_str) return m.hexdigest() -def upload_documents(index: str, documents: list[dict[str, Any]], types: dict[str, ElasticType] | None = None) -> None: +def upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[str, CreateField] | None = None) -> None: """ Upload documents to this index @@ -425,20 +429,21 @@ def upload_documents(index: str, documents: list[dict[str, Any]], types: dict[st :param fields: A mapping of fieldname:UpdateField for field types """ - if types: - create_elastic_fields(index, types) + if fields: + create_fields(index, fields) def es_actions(index, documents): - field_types = get_fields(index) + field_settings = get_fields(index) for document in documents: + for key in document.keys(): if key == "_id": continue - if key not in field_types: + if key not in field_settings: raise ValueError(f"The type for field {key} is not yet specified") - document[key] = coerce_type(document[key], field_types[key].elastic_type) + document[key] = coerce_type(document[key], field_settings[key].elastic_type) if "_id" not in document: - document["_id"] = _get_hash(document) + 
document["_id"] = _get_hash(document, field_settings) yield {"_index": index, **document} actions = list(es_actions(index, documents)) diff --git a/amcat4/models.py b/amcat4/models.py index 7d6096b..4e9ef96 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -56,6 +56,7 @@ class Field(BaseModel): type: AmcatType elastic_type: ElasticType + identifier: bool = False metareader: FieldMetareaderAccess = FieldMetareaderAccess() client_settings: dict[str, Any] = {} @@ -64,6 +65,7 @@ class CreateField(BaseModel): """Model for creating a field""" elastic_type: ElasticType + identifier: bool = False metareader: FieldMetareaderAccess | None = None client_settings: dict[str, Any] | None = None @@ -75,7 +77,7 @@ class UpdateField(BaseModel): client_settings: dict[str, Any] | None = None -def updateField(field: Field, update: UpdateField | Field): +def updateField(field: Field, update: UpdateField | Field | CreateField): for key in update.model_fields_set: setattr(field, key, getattr(update, key)) return field diff --git a/tests/test_api_index.py b/tests/test_api_index.py index a95623d..8c2e6e3 100644 --- a/tests/test_api_index.py +++ b/tests/test_api_index.py @@ -2,7 +2,7 @@ from amcat4 import elastic -from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role, set_fields +from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role, update_fields from amcat4.models import Field from tests.tools import build_headers, post_json, get_json, check, refresh @@ -249,7 +249,7 @@ def test_name_description(client, index, index_name, user, admin): assert indices[index_name]["description"] == "test2" # can set and get summary field - set_fields(index_name, {"party": Field(type="keyword")}) + update_fields(index_name, {"party": Field(type="keyword")}) refresh() check( client.put( diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 910e7a0..c98cdbf 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -7,7 +7,7 @@ update_document, update_tag_by_query, ) -from amcat4.fields import set_fields, get_fields, get_field_values +from amcat4.fields import update_fields, get_fields, get_field_values from amcat4.query import query_documents from tests.conftest import upload @@ -30,7 +30,7 @@ def test_upload_retrieve_document(index): def test_data_coerced(index): """Are field values coerced to the correct field type""" - set_fields(index, {"i": "long"}) + update_fields(index, {"i": "long"}) a = dict(_id="DoccyMcDocface", text="text", title="test-numeric", date="2022-12-13", i="1") upload_documents(index, [a]) d = get_document(index, "DoccyMcDocface") diff --git a/tests/test_index.py b/tests/test_index.py index bd12acc..517b9f6 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -24,7 +24,7 @@ set_global_role, set_guest_role, set_role, - set_fields, + update_fields, ) from amcat4.models import Field from tests.tools import refresh @@ -165,7 +165,7 @@ def test_summary_field(index): modify_index(index, summary_field="doesnotexist") with pytest.raises(Exception): modify_index(index, summary_field="title") - set_fields(index, {"party": Field(type="keyword")}) + update_fields(index, {"party": Field(type="keyword", elastic_type="keyword")}) modify_index(index, summary_field="party") assert get_index(index).summary_field == "party" modify_index(index, summary_field="date") From d8500c706c3605d710f54d9e8efcd661774af89f Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 20 Mar 2024 15:31:50 +0100 Subject: [PATCH 32/80] some upload document 
things --- amcat4/api/index.py | 28 ++++++++-------------------- amcat4/fields.py | 2 +- amcat4/index.py | 6 +++--- amcat4/models.py | 1 + 4 files changed, 13 insertions(+), 24 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index ecbbdbe..3c0052f 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -8,7 +8,7 @@ from fastapi import APIRouter, HTTPException, Response, status, Depends, Body from pydantic import BaseModel -from amcat4 import index +from amcat4 import index, fields as index_fields from amcat4.api.auth import authenticated_user, authenticated_writer, check_role from amcat4.index import refresh_system_index, remove_role, set_role @@ -135,9 +135,9 @@ def delete_index(ix: str, user: str = Depends(authenticated_user)): def upload_documents( ix: str, documents: Annotated[list[dict[str, Any]], Body(description="The documents to upload")], - types: Annotated[ - dict[str, ElasticType] | None, - Body(description="If a field in documents does not yet exist, you need to specify an elastic type"), + new_fields: Annotated[ + dict[str, CreateField] | None, + Body(description="If a field in documents does not yet exist, you can create it on the spot"), ] = None, user: str = Depends(authenticated_user), ): @@ -145,8 +145,7 @@ def upload_documents( Upload documents to this server. Returns a list of ids for the uploaded documents """ check_role(user, index.Role.WRITER, ix) - - return index.upload_documents(ix, documents, types) + return index.upload_documents(ix, documents, new_fields) @app_index.get("/{ix}/documents/{docid}") @@ -221,25 +220,14 @@ def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user) @app_index.post("/{ix}/fields") def create_fields( ix: str, - fields: Annotated[dict[str, ElasticType | CreateField], Body(description="")], + fields: Annotated[dict[str, CreateField], Body(description="")], user: str = Depends(authenticated_user), ): """ Create fields """ check_role(user, index.Role.WRITER, ix) - - types = {} - update_fields = {} - for field, value in fields.items(): - if isinstance(value, CreateField): - types[field] = value.elastic_type - update_fields[field] = UpdateField(**value.model_dump(exclude_none=True)) - else: - types[field] = value - - if len(update_fields) > 0: - index.update_fields(ix, update_fields) + index_fields.create_fields(ix, fields) return "", HTTPStatus.NO_CONTENT @@ -263,7 +251,7 @@ def update_fields( """ check_role(user, index.Role.WRITER, ix) - index.update_fields(ix, fields) + index_fields.update_fields(ix, fields) return "", HTTPStatus.NO_CONTENT diff --git a/amcat4/fields.py b/amcat4/fields.py index 5ee10f4..71c9f86 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -195,7 +195,7 @@ def get_fields(index: str) -> dict[str, Field]: if field not in system_index_fields: update_system_index = True - fields[field] = Field(type=amcat_type, elastic_type=elastic_type, metareader=get_default_metareader(amcat_type)) + fields[field] = get_default_field(elastic_type) else: fields[field] = system_index_fields[field] diff --git a/amcat4/index.py b/amcat4/index.py index 998e178..f9accf7 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -49,7 +49,6 @@ coerce_type, create_fields, get_fields, - update_fields, ) from amcat4.models import CreateField, ElasticType, Field @@ -420,7 +419,7 @@ def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: return m.hexdigest() -def upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[str, CreateField] | None = None) -> None: +def 
upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[str, CreateField] | None = None): """ Upload documents to this index @@ -447,7 +446,8 @@ def es_actions(index, documents): yield {"_index": index, **document} actions = list(es_actions(index, documents)) - elasticsearch.helpers.bulk(es(), actions) + n_submitted, created = elasticsearch.helpers.bulk(es(), actions) + return dict(n_submitted=n_submitted, created=created) def get_document(index: str, doc_id: str, **kargs) -> dict: diff --git a/amcat4/models.py b/amcat4/models.py index 4e9ef96..85d2e88 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -78,6 +78,7 @@ class UpdateField(BaseModel): def updateField(field: Field, update: UpdateField | Field | CreateField): + for key in update.model_fields_set: setattr(field, key, getattr(update, key)) return field From 9f2c7ed821984dd34381ef961138844987051946 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 20 Mar 2024 19:01:22 +0100 Subject: [PATCH 33/80] cleaning stuff up --- amcat4/api/index.py | 58 +++++++++++++++++++++++++++++++++++++-------- amcat4/index.py | 17 ++++++------- amcat4/query.py | 2 +- 3 files changed, 58 insertions(+), 19 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 3c0052f..9c366e1 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -1,12 +1,14 @@ """API Endpoints for document and index management.""" from http import HTTPStatus +from operator import is_ from typing import Annotated, Any, Literal import elasticsearch from elastic_transport import ApiError from fastapi import APIRouter, HTTPException, Response, status, Depends, Body from pydantic import BaseModel +from datetime import datetime from amcat4 import index, fields as index_fields from amcat4.api.auth import authenticated_user, authenticated_writer, check_role @@ -74,9 +76,10 @@ def create_index(new_index: NewIndex, current_user: str = Depends(authenticated_ class ChangeIndex(BaseModel): """Form to update an existing index.""" - guest_role: Literal["ADMIN", "WRITER", "READER", "METAREADER", "NONE"] | None = "NONE" name: str | None = None description: str | None = None + guest_role: Literal["ADMIN", "WRITER", "READER", "METAREADER", "NONE"] | None = None + archive: bool | None = None @app_index.put("/{ix}") @@ -89,20 +92,22 @@ def modify_index(ix: str, data: ChangeIndex, user: str = Depends(authenticated_u User needs admin rights on the index """ check_role(user, index.Role.ADMIN, ix) - guest_role, remove_guest_role = index.Role.NONE, False - if data.guest_role: - role = data.guest_role - if role == "NONE": - remove_guest_role = True - else: - guest_role = index.Role[role] + guest_role = index.Role[data.guest_role] if data.guest_role is not None else None + archived = None + if data.archive is not None: + d = index.get_index(ix) + is_archived = d.archived is not None and d.archived != "" + if is_archived != data.archive: + archived = str(datetime.now()) if data.archive else "" index.modify_index( ix, name=data.name, description=data.description, guest_role=guest_role, - remove_guest_role=remove_guest_role, + archived=archived, + # remove_guest_role=remove_guest_role, + # unarchive=unarchive, ) refresh_system_index() @@ -124,11 +129,44 @@ def view_index(ix: str, user: str = Depends(authenticated_user)): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") +@app_index.post("/{ix}/archive", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) +def archive_index( + ix: str, + archived: 
Annotated[bool, Body(description="Boolean for setting archived to true or false")], + user: str = Depends(authenticated_user), +): + """Archive or unarchive the index. When an index is archived, it restricts usage, and adds a timestamp for when + it was archived. An index can only be deleted if it has been archived for a specific amount of time.""" + check_role(user, index.Role.ADMIN, ix) + try: + d = index.get_index(ix) + is_archived = d.archived is not None + if is_archived == archived: + return + archived_date = datetime.now() if archived else None + index.modify_index(ix, archived=archived_date) + + except index.IndexDoesNotExist: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") + + @app_index.delete("/{ix}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) def delete_index(ix: str, user: str = Depends(authenticated_user)): """Delete the index.""" check_role(user, index.Role.ADMIN, ix) - index.delete_index(ix) + min_archived_before_delete = 7 # days + + try: + d = index.get_index(ix) + if d.archived is None or (datetime.now() - d.archived).days < min_archived_before_delete: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Can only delete an index after it has been archived for at least {min_archived_before_delete} days", + ) + index.delete_index(ix) + + except index.IndexDoesNotExist: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") @app_index.post("/{ix}/documents", status_code=status.HTTP_201_CREATED) diff --git a/amcat4/index.py b/amcat4/index.py index f9accf7..a2bb74b 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -33,6 +33,7 @@ """ import collections +from datetime import datetime from enum import IntEnum from typing import Any, Iterable, Optional, Literal @@ -66,7 +67,7 @@ class Role(IntEnum): Index = collections.namedtuple( "Index", - ["id", "name", "description", "guest_role", "roles", "summary_field"], + ["id", "name", "description", "guest_role", "archived", "roles", "summary_field"], ) @@ -111,11 +112,13 @@ def list_known_indices(email: str | None = None) -> Iterable[Index]: def _index_from_elastic(index): src = index["_source"] guest_role = src.get("guest_role") + return Index( id=index["_id"], name=src.get("name", index["_id"]), description=src.get("description"), - guest_role=guest_role and guest_role != "NONE" and Role[guest_role.upper()], + guest_role=guest_role, + archived=src.get("archived"), roles=_roles_from_elastic(src.get("roles", [])), summary_field=src.get("summary_field"), ) @@ -264,7 +267,7 @@ def set_guest_role(index: str, guest_role: Optional[Role]): """ Set the guest role for this index. 
Set to None to disallow guest access """ - modify_index(index, guest_role=guest_role, remove_guest_role=(guest_role is None)) + modify_index(index, guest_role=Role.NONE if guest_role is None else guest_role) def modify_index( @@ -272,20 +275,18 @@ def modify_index( name: Optional[str] = None, description: Optional[str] = None, guest_role: Optional[Role] = None, - remove_guest_role=False, + archived: Optional[str] = None, summary_field=None, ): - doc = dict( name=name, description=description, guest_role=guest_role and guest_role.value, summary_field=summary_field, + archived=archived, ) - doc = {x: v for (x, v) in doc.items() if v} - if remove_guest_role: - doc["guest_role"] = Role.NONE.value + doc = {x: v for (x, v) in doc.items() if v is not None} if doc: es().update(index=get_settings().system_index, id=index, doc=doc) diff --git a/amcat4/query.py b/amcat4/query.py index beeb4c0..539de62 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -121,7 +121,7 @@ def as_dict(self) -> dict: def query_documents( - index: Union[str, Sequence[str]], + index: Union[str, list[str]], fields: list[FieldSpec] | None = None, queries: dict[str, str] | None = None, filters: dict[str, FilterSpec] | None = None, From fc8949889d376b355a38d7a33cc3275415344b1a Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Thu, 21 Mar 2024 16:45:44 +0100 Subject: [PATCH 34/80] upload things --- amcat4/api/index.py | 13 +++++++++++-- amcat4/index.py | 15 +++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 9c366e1..c5b86dc 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -34,7 +34,12 @@ def index_to_dict(ix: index.Index) -> dict: ix_dict = ix._asdict() guest_role_int = ix_dict.get("guest_role", 0) - ix_dict = dict(id=ix_dict["id"], name=ix_dict["name"], guest_role=index.Role(guest_role_int).name) + ix_dict = dict( + id=ix_dict["id"], + name=ix_dict["name"], + guest_role=index.Role(guest_role_int).name, + archived=ix_dict.get("archived", ""), + ) return ix_dict return [index_to_dict(ix) for ix in index.list_known_indices(current_user)] @@ -177,13 +182,17 @@ def upload_documents( dict[str, CreateField] | None, Body(description="If a field in documents does not yet exist, you can create it on the spot"), ] = None, + operation: Annotated[ + Literal["index", "update", "create"], + Body(description="The operation to perform. (the default, index, is like upsert)"), + ] = "index", user: str = Depends(authenticated_user), ): """ Upload documents to this server. 
Returns a list of ids for the uploaded documents """ check_role(user, index.Role.WRITER, ix) - return index.upload_documents(ix, documents, new_fields) + return index.upload_documents(ix, documents, new_fields, operation) @app_index.get("/{ix}/documents/{docid}") diff --git a/amcat4/index.py b/amcat4/index.py index a2bb74b..fbfe1e2 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -420,7 +420,9 @@ def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: return m.hexdigest() -def upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[str, CreateField] | None = None): +def upload_documents( + index: str, documents: list[dict[str, Any]], fields: dict[str, CreateField] | None = None, op_type="index" +): """ Upload documents to this index @@ -432,7 +434,7 @@ def upload_documents(index: str, documents: list[dict[str, Any]], fields: dict[s if fields: create_fields(index, fields) - def es_actions(index, documents): + def es_actions(index, documents, op_type): field_settings = get_fields(index) for document in documents: @@ -444,11 +446,12 @@ def es_actions(index, documents): document[key] = coerce_type(document[key], field_settings[key].elastic_type) if "_id" not in document: document["_id"] = _get_hash(document, field_settings) - yield {"_index": index, **document} + yield {"_op_type": op_type, "_index": index, **document} - actions = list(es_actions(index, documents)) - n_submitted, created = elasticsearch.helpers.bulk(es(), actions) - return dict(n_submitted=n_submitted, created=created) + actions = list(es_actions(index, documents, op_type)) + successes, failures = elasticsearch.helpers.bulk(es(), actions, stats_only=True, raise_on_error=False) + print(successes, failures) + return dict(successes=successes, failures=failures) def get_document(index: str, doc_id: str, **kargs) -> dict: From e6c1b70caa582f011bc282b9d6995c30e5a3b583 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 22 Mar 2024 16:04:13 +0100 Subject: [PATCH 35/80] upload things --- amcat4/fields.py | 7 ++++++- amcat4/index.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/amcat4/fields.py b/amcat4/fields.py index 71c9f86..cce65c3 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -99,7 +99,12 @@ def create_fields(index: str, fields: dict[str, CreateField]): mapping: dict[str, Any] = {} current_fields = {k: v for k, v in _get_index_fields(index)} + new_fields: dict[str, CreateField] = {} + for field, settings in fields.items(): + if field not in current_fields: + new_fields[field] = settings + if TYPEMAP_ES_TO_AMCAT.get(settings.elastic_type) is None: raise ValueError(f"Field type {settings.elastic_type} not supported by AmCAT") @@ -117,7 +122,7 @@ def create_fields(index: str, fields: dict[str, CreateField]): mapping[field]["format"] = "strict_date_optional_time" es().indices.put_mapping(index=index, properties=mapping) - update_fields(index, fields) + update_fields(index, new_fields) def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: diff --git a/amcat4/index.py b/amcat4/index.py index fbfe1e2..4b25c2b 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -436,6 +436,7 @@ def upload_documents( def es_actions(index, documents, op_type): field_settings = get_fields(index) + for document in documents: for key in document.keys(): From 5a122ec180bd7e34e2074161a74fb7b222d14325 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Tue, 26 Mar 2024 15:06:24 +0100 Subject: [PATCH 36/80] unit test adventure --- amcat4/aggregate.py | 2 +- amcat4/api/auth.py | 
9 ++- amcat4/api/index.py | 5 +- amcat4/api/query.py | 3 +- amcat4/api/users.py | 4 +- amcat4/index.py | 39 ++++++------- amcat4/query.py | 9 ++- tests/conftest.py | 32 ++++++----- tests/test_aggregate.py | 3 +- tests/test_api_documents.py | 21 +++---- tests/test_api_index.py | 32 +++++------ tests/test_api_metareader.py | 4 +- tests/test_api_pagination.py | 47 +++++++-------- tests/test_api_query.py | 107 +++++++++-------------------------- tests/test_api_user.py | 4 ++ tests/test_elastic.py | 4 +- tests/test_index.py | 2 +- 17 files changed, 137 insertions(+), 190 deletions(-) diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index fdcfaba..95900d1 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -318,7 +318,7 @@ def query_aggregate( # we return the data and the last_after cursor. If the user needs to collect the rest, # they need to paginate stop_after = 1000 - gen = _aggregate_results(index, axes, queries, filters, aggregations, after) + gen = _aggregate_results(indices, axes, queries, filters, aggregations, after) data = list() last_after = None for rows, after in gen: diff --git a/amcat4/api/auth.py b/amcat4/api/auth.py index 04c55ed..94327e4 100644 --- a/amcat4/api/auth.py +++ b/amcat4/api/auth.py @@ -1,4 +1,6 @@ """Helper methods for authentication and authorization.""" + +from argparse import ONE_OR_MORE import functools import logging from datetime import datetime @@ -12,7 +14,7 @@ from amcat4.models import FieldSpec from amcat4.config import get_settings, AuthOptions -from amcat4.index import Role, get_role, get_global_role +from amcat4.index import ADMIN_USER, GUEST_USER, Role, get_role, get_global_role from amcat4.fields import get_fields oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=False) @@ -99,6 +101,7 @@ def check_role(user: str, required_role: Role, index: str, required_global_role: return get_role(index, user) # Global role check was false, so now check local role actual_role = get_role(index, user) + if get_settings().auth == AuthOptions.no_auth: return actual_role elif actual_role and actual_role >= required_role: @@ -179,9 +182,9 @@ async def authenticated_user(token: str = Depends(oauth2_scheme)) -> str: auth = get_settings().auth if token is None: if auth == AuthOptions.no_auth: - return "admin" + return ADMIN_USER elif auth == AuthOptions.allow_guests: - return "guest" + return GUEST_USER else: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, diff --git a/amcat4/api/index.py b/amcat4/api/index.py index c5b86dc..f14cd60 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -38,6 +38,7 @@ def index_to_dict(ix: index.Index) -> dict: id=ix_dict["id"], name=ix_dict["name"], guest_role=index.Role(guest_role_int).name, + description=ix_dict.get("description", ""), archived=ix_dict.get("archived", ""), ) return ix_dict @@ -178,7 +179,7 @@ def delete_index(ix: str, user: str = Depends(authenticated_user)): def upload_documents( ix: str, documents: Annotated[list[dict[str, Any]], Body(description="The documents to upload")], - new_fields: Annotated[ + fields: Annotated[ dict[str, CreateField] | None, Body(description="If a field in documents does not yet exist, you can create it on the spot"), ] = None, @@ -192,7 +193,7 @@ def upload_documents( Upload documents to this server. 
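     A hedged request sketch (illustrative only; the field names and values are
     hypothetical, and each entry under "fields" follows the CreateField model):
 
         {"documents": [{"title": "t1", "date": "2024-01-01", "text": "..."}],
          "fields": {"title": {"elastic_type": "text"},
                     "date": {"elastic_type": "date"},
                     "text": {"elastic_type": "text"}},
          "operation": "index"}
 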
Returns a list of ids for the uploaded documents """ check_role(user, index.Role.WRITER, ix) - return index.upload_documents(ix, documents, new_fields, operation) + return index.upload_documents(ix, documents, fields, operation) @app_index.get("/{ix}/documents/{docid}") diff --git a/amcat4/api/query.py b/amcat4/api/query.py index 52ca336..c806c1f 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -211,7 +211,7 @@ def query_documents_post( }, ), ] = None, - per_page: Annotated[int, Body(description="Number of documents per page")] = 10, + per_page: Annotated[int, Body(le=200, description="Number of documents per page")] = 10, page: Annotated[int, Body(description="Which page to retrieve")] = 0, scroll: Annotated[ str | None, @@ -231,7 +231,6 @@ def query_documents_post( Returns a JSON object {data: [...], meta: {total_count, per_page, page_count, page|scroll_id}} """ - indices = index.split(",") fieldspecs = get_or_validate_allowed_fields(user, indices, _standardize_fieldspecs(fields)) diff --git a/amcat4/api/users.py b/amcat4/api/users.py index 511a0ec..7029a56 100644 --- a/amcat4/api/users.py +++ b/amcat4/api/users.py @@ -15,7 +15,7 @@ from amcat4 import index from amcat4.api.auth import authenticated_user, authenticated_admin, check_global_role from amcat4.config import get_settings, validate_settings -from amcat4.index import Role, set_global_role, get_global_role, user_exists +from amcat4.index import ADMIN_USER, GUEST_USER, Role, set_global_role, get_global_role, user_exists app_users = APIRouter(tags=["users"]) @@ -74,7 +74,7 @@ def _get_user(email, current_user): if current_user != email: check_global_role(current_user, Role.WRITER) global_role = get_global_role(email) - if email in ("admin", "guest") or global_role is None: + if email in (ADMIN_USER, GUEST_USER) or global_role is Role.NONE: raise HTTPException(404, detail=f"User {email} unknown") else: return {"email": email, "role": global_role.name} diff --git a/amcat4/index.py b/amcat4/index.py index 4b25c2b..3d93500 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -42,9 +42,10 @@ import elasticsearch.helpers from elasticsearch import NotFoundError +from httpx import get # from amcat4.api.common import py2dict -from amcat4.config import get_settings +from amcat4.config import AuthOptions, get_settings from amcat4.elastic import es from amcat4.fields import ( coerce_type, @@ -62,6 +63,7 @@ class Role(IntEnum): ADMIN = 40 +ADMIN_USER = "_admin" GUEST_USER = "_guest" GLOBAL_ROLES = "_global" @@ -111,13 +113,13 @@ def list_known_indices(email: str | None = None) -> Iterable[Index]: def _index_from_elastic(index): src = index["_source"] - guest_role = src.get("guest_role") + guest_role = src.get("guest_role", "NONE") return Index( id=index["_id"], name=src.get("name", index["_id"]), description=src.get("description"), - guest_role=guest_role, + guest_role=Role[guest_role], archived=src.get("archived"), roles=_roles_from_elastic(src.get("roles", [])), summary_field=src.get("summary_field"), @@ -149,7 +151,6 @@ def create_index( pass es().indices.create(index=index, mappings={"properties": {}}) - register_index( index, guest_role=guest_role or Role.NONE, @@ -158,8 +159,6 @@ def create_index( admin=admin, ) - # update_fields(index, DEFAULT_FIELDS) - def register_index( index: str, @@ -178,11 +177,6 @@ def register_index( raise ValueError(f"Index {index} is already registered") roles = [dict(email=admin, role="ADMIN")] if admin else [] - if guest_role is not None: - guest_role_int = guest_role.value - else: - 
guest_role_int = Role.NONE.value - es().index( index=system_index, id=index, @@ -190,7 +184,7 @@ def register_index( name=(name or index), roles=roles, description=description, - guest_role=guest_role_int, + guest_role=guest_role.name if guest_role is not None else "NONE", ), ) refresh_index(system_index) @@ -249,6 +243,7 @@ def set_role(index: str, email: str, role: Optional[Role]): if email not in roles_dict: return # Nothing to change del roles_dict[email] + es().update( index=system_index, id=index, @@ -281,7 +276,7 @@ def modify_index( doc = dict( name=name, description=description, - guest_role=guest_role and guest_role.value, + guest_role=guest_role.name if guest_role is not None else "NONE", summary_field=summary_field, archived=archived, ) @@ -341,6 +336,13 @@ def get_role(index: str, email: str) -> Role: if index == GLOBAL_ROLES: return Role.NONE + # are guests allowed? + if get_settings().auth == AuthOptions.authorized_users_only: + # only allow guests if authorized at server level + global_role = get_global_role(email, only_es=True) + if global_role == Role.NONE: + return Role.NONE + return get_guest_role(index) @@ -358,7 +360,7 @@ def get_guest_role(index: str) -> Role: except NotFoundError: raise IndexDoesNotExist(index) role = d["_source"].get("guest_role") - if role and role.lower() != "none": + if role and role in Role.__members__: return Role[role] return Role.NONE @@ -371,7 +373,7 @@ def get_global_role(email: str, only_es: bool = False) -> Role: """ # The 'admin' user is given to everyone in the no_auth scenario if only_es is False: - if email == get_settings().admin_email or email == "admin": + if email == get_settings().admin_email or email == ADMIN_USER: return Role.ADMIN return get_role(index=GLOBAL_ROLES, email=email) @@ -436,23 +438,22 @@ def upload_documents( def es_actions(index, documents, op_type): field_settings = get_fields(index) - for document in documents: for key in document.keys(): if key == "_id": continue if key not in field_settings: - raise ValueError(f"The type for field {key} is not yet specified") + raise ValueError(f"The type for field '{key}' is not yet specified") document[key] = coerce_type(document[key], field_settings[key].elastic_type) if "_id" not in document: document["_id"] = _get_hash(document, field_settings) yield {"_op_type": op_type, "_index": index, **document} actions = list(es_actions(index, documents, op_type)) + ids = [doc["_id"] for doc in actions] successes, failures = elasticsearch.helpers.bulk(es(), actions, stats_only=True, raise_on_error=False) - print(successes, failures) - return dict(successes=successes, failures=failures) + return dict(ids=ids, successes=successes, failures=failures) def get_document(index: str, doc_id: str, **kargs) -> dict: diff --git a/amcat4/query.py b/amcat4/query.py index 539de62..1a17dcd 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -133,7 +133,7 @@ def query_documents( scroll_id: str | None = None, highlight: bool = False, **kwargs, -) -> QueryResult: +) -> QueryResult | None: """ Conduct a query_string query, returning the found documents. 
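# (illustrative call sketch, not part of the patch: FieldSpec is imported from
#  amcat4.models, and the FieldSpec(name=...) shape is an assumption; the sort
#  shape, a list of {field: options} dicts, matches the handling in the hunk
#  below; "my_index" is a hypothetical index name)
#
#   res = query_documents("my_index", fields=[FieldSpec(name="title")],
#                         queries={"q1": "test*"},
#                         sort=[{"date": {"order": "desc"}}], per_page=20)
#   if res is not None:
#       docs = res.data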
@@ -167,11 +167,14 @@ def query_documents( kwargs["scroll"] = "2m" if (not scroll or scroll is True) else scroll if sort is not None: - kwargs["sort"] = sort + kwargs["sort"] = [] + for s in sort: + for k, v in s.items(): + kwargs["sort"].append({k: dict(v)}) if scroll_id: result = es().scroll(scroll_id=scroll_id, **kwargs) if not result["hits"]["hits"]: - return QueryResult(data=[]) + return None else: h = query_highlight(fields, highlight) if fields is not None else None body = build_body(queries, filters, h) diff --git a/tests/conftest.py b/tests/conftest.py index 96a3b87..5bd9c37 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ set_global_role, upload_documents, ) -from amcat4.models import UpdateField +from amcat4.models import CreateField from tests.middlecat_keypair import PUBLIC_KEY UNITS = [ @@ -136,25 +136,18 @@ def guest_index(): delete_index(index, ignore_missing=True) -def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, UpdateField] | None = None): +def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, CreateField] | None = None): """ Upload these docs to the index, giving them an incremental id, and flush """ - ids = [] - for i, doc in enumerate(docs): - id = str(i) - ids.append(id) - defaults = {"title": "title", "date": "2018-01-01", "text": "text", "_id": id} - for k, v in defaults.items(): - if k not in doc: - doc[k] = v - upload_documents(index, docs, fields) + res = upload_documents(index, docs, fields) refresh_index(index) - return ids + return res["ids"] TEST_DOCUMENTS = [ { + "which": 0, "cat": "a", "subcat": "x", "i": 1, @@ -162,6 +155,7 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, UpdateField "text": "this is a text", }, { + "which": 1, "cat": "a", "subcat": "x", "i": 2, @@ -169,6 +163,7 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, UpdateField "text": "a test text", }, { + "which": 2, "cat": "a", "subcat": "y", "i": 11, @@ -177,6 +172,7 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, UpdateField "title": "bla", }, { + "which": 3, "cat": "b", "subcat": "y", "i": 31, @@ -191,7 +187,15 @@ def populate_index(index): upload( index, TEST_DOCUMENTS, - fields={"cat": UpdateField(type="keyword"), "subcat": UpdateField(type="keyword"), "i": UpdateField(type="long")}, + fields={ + "text": CreateField(elastic_type="text"), + "title": CreateField(elastic_type="keyword"), + "date": CreateField(elastic_type="date"), + "cat": CreateField(elastic_type="keyword"), + "subcat": CreateField(elastic_type="keyword"), + "i": CreateField(elastic_type="long"), + "which": CreateField(elastic_type="short"), + }, ) return TEST_DOCUMENTS @@ -214,7 +218,7 @@ def index_many(): upload( index, [dict(id=i, pagenr=abs(10 - i), text=text) for (i, text) in enumerate(["odd", "even"] * 10)], - fields={"id": UpdateField(type="long"), "pagenr": UpdateField(type="long")}, + fields={"id": CreateField(elastic_type="long"), "pagenr": CreateField(elastic_type="long")}, ) yield index delete_index(index, ignore_missing=True) diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index c01c493..c187614 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -3,6 +3,7 @@ from amcat4.aggregate import query_aggregate, Axis, Aggregation from amcat4.api.query import _standardize_queries +from amcat4.models import CreateField, Field from tests.conftest import upload from tests.tools import dictset @@ -114,7 +115,7 @@ def test_aggregate_datefunctions(index: str): 
"2018-03-07T23:59:00", # wednesday evening ] ] - upload(index, docs) + upload(index, docs, fields=dict(date=CreateField(elastic_type="date"))) assert q(Axis("date", interval="day")) == { date(2018, 1, 1): 2, date(2018, 1, 11): 1, diff --git a/tests/test_api_documents.py b/tests/test_api_documents.py index ea21d5c..83b0a44 100644 --- a/tests/test_api_documents.py +++ b/tests/test_api_documents.py @@ -7,22 +7,16 @@ def test_documents_unauthorized(client, index, user): docs = {"documents": []} check(client.post(f"index/{index}/documents", json=docs), 401) check( - client.post( - f"index/{index}/documents", json=docs, headers=build_headers(user=user) - ), + client.post(f"index/{index}/documents", json=docs, headers=build_headers(user=user)), 401, ) check(client.put(f"index/{index}/documents/1", json={}), 401) check( - client.put( - f"index/{index}/documents/1", json={}, headers=build_headers(user=user) - ), + client.put(f"index/{index}/documents/1", json={}, headers=build_headers(user=user)), 401, ) check(client.get(f"index/{index}/documents/1"), 401) - check( - client.get(f"index/{index}/documents/1", headers=build_headers(user=user)), 401 - ) + check(client.get(f"index/{index}/documents/1", headers=build_headers(user=user)), 401) def test_documents(client, index, user): @@ -33,9 +27,12 @@ def test_documents(client, index, user): f"index/{index}/documents", user=user, json={ - "documents": [ - {"_id": "id", "title": "a title", "text": "text", "date": "2020-01-01"} - ] + "documents": [{"_id": "id", "title": "a title", "text": "text", "date": "2020-01-01"}], + "fields": { + "title": {"elastic_type": "text"}, + "text": {"elastic_type": "text"}, + "date": {"elastic_type": "date"}, + }, }, ) url = f"index/{index}/documents/id" diff --git a/tests/test_api_index.py b/tests/test_api_index.py index 8c2e6e3..0105b04 100644 --- a/tests/test_api_index.py +++ b/tests/test_api_index.py @@ -2,8 +2,9 @@ from amcat4 import elastic -from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role, update_fields -from amcat4.models import Field +from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role +from amcat4.fields import update_fields +from amcat4.models import CreateField, Field, UpdateField from tests.tools import build_headers, post_json, get_json, check, refresh @@ -79,7 +80,12 @@ def test_fields_upload(client: TestClient, user: str, index: str): } for i, x in enumerate(["a", "a", "b"]) ], - "fields": {"x": "keyword"}, + "fields": { + "title": dict(elastic_type="text"), + "text": dict(elastic_type="text"), + "date": dict(elastic_type="date"), + "x": dict(elastic_type="keyword"), + }, } # You need METAREADER permissions to read fields, and WRITER to upload docs @@ -90,9 +96,11 @@ def test_fields_upload(client: TestClient, user: str, index: str): ) set_role(index, user, Role.METAREADER) + + ## can get fields fields = get_json(client, f"/index/{index}/fields", user=user) or {} - assert set(fields.keys()) == {"title", "date", "text", "url"} - assert fields["date"]["type"] == "date" + ## but should still be empty, since no fields were created + assert len(set(fields.keys())) == 0 check( client.post(f"/index/{index}/documents", headers=build_headers(user), json=body), 401, @@ -234,7 +242,6 @@ def test_name_description(client, index, index_name, user, admin): id=index_name, description="test2", guest_role="METAREADER", - summary_field="party", ), headers=build_headers(admin), ), @@ -247,16 +254,3 @@ def test_name_description(client, index, index_name, user, 
admin): indices = {ix["id"]: ix for ix in get_json(client, "/index") or []} assert indices[index]["description"] == "ooktest" assert indices[index_name]["description"] == "test2" - - # can set and get summary field - update_fields(index_name, {"party": Field(type="keyword")}) - refresh() - check( - client.put( - f"/index/{index_name}", - json=dict(summary_field="party"), - headers=build_headers(admin), - ), - 200, - ) - assert (get_json(client, f"/index/{index_name}", user=admin) or {})["summary_field"] == "party" diff --git a/tests/test_api_metareader.py b/tests/test_api_metareader.py index d5d6f26..b8dda4d 100644 --- a/tests/test_api_metareader.py +++ b/tests/test_api_metareader.py @@ -11,10 +11,10 @@ def create_index_metareader(client, index, admin): def set_metareader_access(client, index, admin, metareader): - client.post( + client.put( f"/index/{index}/fields", headers=build_headers(admin), - json={"text": {"type": "text", "metareader": metareader}}, + json={"text": {"metareader": metareader}}, ) diff --git a/tests/test_api_pagination.py b/tests/test_api_pagination.py index ac58863..5c1ada1 100644 --- a/tests/test_api_pagination.py +++ b/tests/test_api_pagination.py @@ -1,4 +1,5 @@ from amcat4.index import Role, set_role +from amcat4.models import CreateField from tests.conftest import upload from tests.tools import get_json, post_json @@ -10,56 +11,46 @@ def test_pagination(client, index, user): # TODO. Tests are not independent. test_pagination fails if run directly after other tests. # Probably delete_index doesn't fully delete - upload(index, docs=[{"i": i} for i in range(66)]) - url = f"/index/{index}/documents" - r = get_json(client, url, user=user, params={"sort": "i", "per_page": 20, "fields": ["i"]}) + upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(elastic_type="long")}) + url = f"/index/{index}/query" + r = post_json(client, url, user=user, json={"sort": "i", "per_page": 20, "fields": ["i"]}, expected=200) + assert r["meta"]["per_page"] == 20 assert r["meta"]["page"] == 0 assert r["meta"]["page_count"] == 4 assert {h["i"] for h in r["results"]} == set(range(20)) - r = get_json( - client, - url, - user=user, - params={"sort": "i", "per_page": 20, "page": 3, "fields": ["i"]}, - ) + r = post_json(client, url, user=user, json={"sort": "i", "per_page": 20, "page": 3, "fields": ["i"]}, expected=200) assert r["meta"]["page"] == 3 assert {h["i"] for h in r["results"]} == {60, 61, 62, 63, 64, 65} - r = get_json(client, url, user=user, params={"sort": "i", "per_page": 20, "page": 4}) + r = post_json(client, url, user=user, json={"sort": "i", "per_page": 20, "page": 4, "fields": ["i"]}, expected=200) assert len(r["results"]) == 0 - # Test POST query - - r = post_json( - client, - f"/index/{index}/query", - expected=200, - user=user, - json={"sort": "i", "per_page": 20, "page": 3, "fields": ["i"]}, - ) - assert r["meta"]["page"] == 3 - assert {h["i"] for h in r["results"]} == {60, 61, 62, 63, 64, 65} def test_scroll(client, index, user): set_role(index, user, Role.READER) - upload(index, docs=[{"i": i} for i in range(66)]) - url = f"/index/{index}/documents" - r = get_json( + upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(elastic_type="long")}) + url = f"/index/{index}/query" + r = post_json( client, url, user=user, - params={"sort": "i:desc", "per_page": 30, "scroll": "5m", "fields": ["i"]}, + json={"scroll": "5m", "sort": [{"i": {"order": "desc"}}], "per_page": 30, "fields": ["i"]}, + expected=200, ) + scroll_id = 
r["meta"]["scroll_id"] assert scroll_id is not None assert {h["i"] for h in r["results"]} == set(range(36, 66)) - r = get_json(client, url, user=user, params={"scroll_id": scroll_id}) + + r = post_json(client, url, user=user, json={"scroll_id": scroll_id}, expected=200) assert {h["i"] for h in r["results"]} == set(range(6, 36)) assert r["meta"]["scroll_id"] == scroll_id - r = get_json(client, url, user=user, params={"scroll_id": scroll_id}) + r = post_json(client, url, user=user, json={"scroll_id": scroll_id}, expected=200) assert {h["i"] for h in r["results"]} == set(range(6)) + # Scrolling past the edge should return 404 - get_json(client, url, user=user, params={"scroll_id": scroll_id}, expected=404) + post_json(client, url, user=user, json={"scroll_id": scroll_id}, expected=404) + # Test POST to query endpoint r = post_json( client, diff --git a/tests/test_api_query.py b/tests/test_api_query.py index e6fae6c..5f54f20 100644 --- a/tests/test_api_query.py +++ b/tests/test_api_query.py @@ -1,52 +1,16 @@ from amcat4.index import Role, refresh_index, set_role +from amcat4.models import CreateField, FieldSpec from amcat4.query import query_documents from tests.conftest import upload from tests.tools import get_json, post_json, dictset -def test_query_get(client, index_docs, user): - """Can we run a simple query?""" - - def q(**query_string): - return get_json( - client, - f"/index/{index_docs}/documents", - user=user, - params=query_string, - )["results"] - - def qi(**query_string): - return {int(doc["_id"]) for doc in q(**query_string)} - - # TODO make sure all auth is checked in test_api_query_auth - - # Query strings - assert qi(q="text") == {0, 1} - assert qi(q="test*") == {1, 2, 3} - - # Filters - assert qi(cat="a") == {0, 1, 2} - assert qi(cat="b", q="test*") == {3} - assert qi(date="2018-01-01") == {0, 3} - assert qi(date__gte="2018-02-01") == {1, 2} - assert qi(date__gt="2018-02-01") == {2} - assert qi(date__gte="2018-02-01", date__lt="2020-01-01") == {1} - - # Can we request specific fields? - all_fields = {"_id", "cat", "subcat", "i", "date", "text", "title"} - assert set(q()[0].keys()) == all_fields - assert set(q(fields="cat")[0].keys()) == {"_id", "cat"} - assert set(q(fields="date,title")[0].keys()) == {"_id", "date", "title"} - - def test_query_post(client, index_docs, user): def q(**body): - return post_json( - client, f"/index/{index_docs}/query", user=user, expected=200, json=body - )["results"] + return post_json(client, f"/index/{index_docs}/query", user=user, expected=200, json=body)["results"] def qi(**query_string): - return {int(doc["_id"]) for doc in q(**query_string)} + return {doc["which"] for doc in q(**query_string)} # Query strings assert qi(queries="text") == {0, 1} @@ -64,10 +28,10 @@ def qi(**query_string): assert qi(filters={"cat": {"values": ["a"]}}) == {0, 1, 2} # Can we request specific fields? 
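A quick illustration of the request shape these assertions pin down, as a hedged client-side sketch (the server address is hypothetical, auth headers are omitted, and any HTTP client would do):

    import httpx

    # POST /index/{ix}/query with an explicit field list; the server always
    # includes _id in each returned document.
    r = httpx.post(
        "http://localhost:5000/index/my-index/query",
        json={"fields": ["date", "title"], "per_page": 10},
    )
    for doc in r.json()["results"]:
        print(doc["_id"], doc.get("date"), doc.get("title"))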
- all_fields = {"_id", "cat", "subcat", "i", "date", "text", "title"} + all_fields = {"_id", "cat", "subcat", "i", "date", "text", "which"} assert set(q()[0].keys()) == all_fields assert set(q(fields=["cat"])[0].keys()) == {"_id", "cat"} - assert set(q(fields=["date", "title"])[0].keys()) == {"_id", "date", "title"} + assert set(q(fields=["date", "title"])[0].keys()) == {"_id", "date"} def test_aggregate(client, index_docs, user): @@ -93,12 +57,8 @@ def test_aggregate(client, index_docs, user): "aggregations": [{"field": "i", "function": "avg"}], }, ) - assert dictset(r["data"]) == dictset( - [{"avg_i": 1.5, "n": 2, "subcat": "x"}, {"avg_i": 21.0, "n": 2, "subcat": "y"}] - ) - assert r["meta"]["aggregations"] == [ - {"field": "i", "function": "avg", "type": "long", "name": "avg_i"} - ] + assert dictset(r["data"]) == dictset([{"avg_i": 1.5, "n": 2, "subcat": "x"}, {"avg_i": 21.0, "n": 2, "subcat": "y"}]) + assert r["meta"]["aggregations"] == [{"field": "i", "function": "avg", "type": "number", "name": "avg_i"}] # test filtered aggregate r = post_json( @@ -132,32 +92,22 @@ def test_multiple_index(client, index_docs, index, user): upload( index, [{"text": "also a text", "i": -1, "cat": "c"}], - fields={"cat": "keyword", "i": "long"}, + fields={ + "text": CreateField(elastic_type="text"), + "cat": CreateField(elastic_type="keyword"), + "i": CreateField(elastic_type="long"), + }, ) indices = f"{index},{index_docs}" - assert ( - len( - get_json( - client, - f"/index/{indices}/documents", - user=user, - params=dict(fields="_id"), - )["results"] - ) - == 5 - ) - assert ( - len( - post_json( - client, - f"/index/{indices}/query", - user=user, - expected=200, - json=dict(fields=["_id"]), - )["results"] - ) - == 5 + + r = post_json( + client, + f"/index/{indices}/query", + user=user, + expected=200, + json=dict(fields=["_id", "cat", "i"]), ) + assert len(r["results"]) == 5 r = post_json( client, @@ -166,9 +116,8 @@ def test_multiple_index(client, index_docs, index, user): json={"axes": [{"field": "cat"}], "fields": ["_id"]}, expected=200, ) - assert dictset(r["data"]) == dictset( - [{"cat": "a", "n": 3}, {"n": 1, "cat": "b"}, {"n": 1, "cat": "c"}] - ) + print(r) + assert dictset(r["data"]) == dictset([{"cat": "a", "n": 3}, {"n": 1, "cat": "b"}, {"n": 1, "cat": "c"}]) def test_aggregate_datemappings(client, index_docs, user): @@ -203,8 +152,8 @@ def test_aggregate_datemappings(client, index_docs, user): def test_query_tags(client, index_docs, user): def tags(): return { - doc["_id"]: doc["tag"] - for doc in query_documents(index_docs, fields=["tag"]).data + doc["which"]: doc["tag"] + for doc in query_documents(index_docs, fields=[FieldSpec(name="tag"), FieldSpec(name="which")]).data if doc.get("tag") } @@ -217,7 +166,7 @@ def tags(): json=dict(action="add", field="tag", tag="x", filters={"cat": "a"}), ) refresh_index(index_docs) - assert tags() == {"0": ["x"], "1": ["x"], "2": ["x"]} + assert tags() == {0: ["x"], 1: ["x"], 2: ["x"]} post_json( client, f"/index/{index_docs}/tags_update", @@ -226,13 +175,13 @@ def tags(): json=dict(action="remove", field="tag", tag="x", queries=["text"]), ) refresh_index(index_docs) - assert tags() == {"2": ["x"]} + assert tags() == {2: ["x"]} post_json( client, f"/index/{index_docs}/tags_update", user=user, expected=204, - json=dict(action="add", field="tag", tag="y", ids=["1", "2"]), + json=dict(action="add", field="tag", tag="y", filters={"which": {"gte": 1, "lte": 2}}), ) refresh_index(index_docs) - assert tags() == {"1": ["y"], "2": ["x", "y"]} + assert tags() == 
{1: ["y"], 2: ["x", "y"]} diff --git a/tests/test_api_user.py b/tests/test_api_user.py index 37b6bd8..f15f206 100644 --- a/tests/test_api_user.py +++ b/tests/test_api_user.py @@ -29,6 +29,10 @@ def test_auth(client: TestClient, user, admin, index): assert client.get(f"/index/{index}", headers=build_headers(admin)).status_code == 200 with set_auth(AuthOptions.authorized_users_only): # Only users with a index-level role can access other indices (even as guest) + # KW: I don't understand what this means. Do we need to check every index? + # Now changed it so that only users with a server level role can access other indices as guest. + # In other words, in this auth mode you either need index level authorization or server level + # authorization with guest access. (this did pass the test) set_guest_role(index, Role.READER) refresh() assert client.get(f"/index/{index}").status_code == 401 diff --git a/tests/test_elastic.py b/tests/test_elastic.py index c98cdbf..498ca4d 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -7,7 +7,7 @@ update_document, update_tag_by_query, ) -from amcat4.fields import update_fields, get_fields, get_field_values +from amcat4.fields import update_fields, get_fields, field_values from amcat4.query import query_documents from tests.conftest import upload @@ -51,7 +51,7 @@ def test_fields(index): def test_values(index): """Can we get values for a specific field""" upload(index, [dict(bla=x) for x in ["odd", "even", "even"] * 10], fields={"bla": "keyword"}) - assert set(get_field_values(index, "bla", 10)) == {"odd", "even"} + assert set(field_values(index, "bla", 10)) == {"odd", "even"} def test_update(index_docs): diff --git a/tests/test_index.py b/tests/test_index.py index 517b9f6..c5c3685 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -24,8 +24,8 @@ set_global_role, set_guest_role, set_role, - update_fields, ) +from amcat4.fields import update_fields from amcat4.models import Field from tests.tools import refresh From 8506948cc3605a339f3dc0cbae39210f51cfe498 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Tue, 26 Mar 2024 17:01:46 +0100 Subject: [PATCH 37/80] unit tests jeej --- amcat4/api/index.py | 13 +++++++++--- amcat4/fields.py | 25 ++++++++++++++++++++--- amcat4/index.py | 5 +++-- tests/conftest.py | 27 ++++++------------------- tests/test_api_query.py | 18 ++++++++--------- tests/test_elastic.py | 44 +++++++++++++++++++++++++++++++++-------- tests/test_query.py | 4 ++-- 7 files changed, 88 insertions(+), 48 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index f14cd60..2a09830 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -180,8 +180,10 @@ def upload_documents( ix: str, documents: Annotated[list[dict[str, Any]], Body(description="The documents to upload")], fields: Annotated[ - dict[str, CreateField] | None, - Body(description="If a field in documents does not yet exist, you can create it on the spot"), + dict[str, str | CreateField] | None, + Body( + description="If a field in documents does not yet exist, you can create it on the spot. 
If you only need to specify the type, and use the default settings, you can use the short form: {field: type}" + ), ] = None, operation: Annotated[ Literal["index", "update", "create"], @@ -268,7 +270,12 @@ def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user) @app_index.post("/{ix}/fields") def create_fields( ix: str, - fields: Annotated[dict[str, CreateField], Body(description="")], + fields: Annotated[ + dict[str, str | CreateField], + Body( + description="Either a dictionary that maps field names to field specifications ({field: {elastic_type: text, identifier:True}}), or a simplified version that only specifies the type ({field: type})" + ), + ], user: str = Depends(authenticated_user), ): """ diff --git a/amcat4/fields.py b/amcat4/fields.py index cce65c3..e52c6be 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -12,7 +12,7 @@ system index """ -from typing import Any, Iterator +from typing import Any, Iterator, get_args, cast from elasticsearch import NotFoundError @@ -77,6 +77,17 @@ def get_default_field(elastic_type: ElasticType): return Field(type=amcat_type, elastic_type=elastic_type, metareader=get_default_metareader(amcat_type)) +def _standardize_createfields(fields: dict[str, str | CreateField]) -> dict[str, CreateField]: + sfields = {} + for k, v in fields.items(): + if isinstance(v, str): + assert v in get_args(ElasticType), f"Unknown elastic type {v}" + sfields[k] = CreateField(elastic_type=cast(ElasticType, v)) + else: + sfields[k] = v + return sfields + + def coerce_type(value: Any, elastic_type: ElasticType): """ Coerces values into the respective type in elastic @@ -92,16 +103,24 @@ def coerce_type(value: Any, elastic_type: ElasticType): return float(value) # TODO: check coercion / validation for object, vector and geo types + if elastic_type in ["object", "flattened", "nested"]: + return value + if elastic_type in ["dense_vector"]: + return value + if elastic_type in ["geo_point"]: + return value + return value -def create_fields(index: str, fields: dict[str, CreateField]): +def create_fields(index: str, fields: dict[str, str | CreateField]): mapping: dict[str, Any] = {} current_fields = {k: v for k, v in _get_index_fields(index)} new_fields: dict[str, CreateField] = {} + sfields = _standardize_createfields(fields) - for field, settings in fields.items(): + for field, settings in sfields.items(): if field not in current_fields: new_fields[field] = settings diff --git a/amcat4/index.py b/amcat4/index.py index 3d93500..53151e3 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -35,7 +35,7 @@ import collections from datetime import datetime from enum import IntEnum -from typing import Any, Iterable, Optional, Literal +from typing import Any, Iterable, Optional, Literal, cast, get_args import hashlib import json @@ -48,6 +48,7 @@ from amcat4.config import AuthOptions, get_settings from amcat4.elastic import es from amcat4.fields import ( + _standardize_createfields, coerce_type, create_fields, get_fields, @@ -423,7 +424,7 @@ def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: def upload_documents( - index: str, documents: list[dict[str, Any]], fields: dict[str, CreateField] | None = None, op_type="index" + index: str, documents: list[dict[str, Any]], fields: dict[str, str | CreateField] | None = None, op_type="index" ): """ Upload documents to this index diff --git a/tests/conftest.py b/tests/conftest.py index 5bd9c37..9aa5f84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -136,7 +136,7 @@ def 
guest_index(): delete_index(index, ignore_missing=True) -def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, CreateField] | None = None): +def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, str | CreateField] | None = None): """ Upload these docs to the index, giving them an incremental id, and flush """ @@ -146,24 +146,10 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, CreateField TEST_DOCUMENTS = [ + {"_id": 0, "cat": "a", "subcat": "x", "i": 1, "date": "2018-01-01", "text": "this is a text", "title": "title"}, + {"_id": 1, "cat": "a", "subcat": "x", "i": 2, "date": "2018-02-01", "text": "a test text", "title": "title"}, { - "which": 0, - "cat": "a", - "subcat": "x", - "i": 1, - "date": "2018-01-01", - "text": "this is a text", - }, - { - "which": 1, - "cat": "a", - "subcat": "x", - "i": 2, - "date": "2018-02-01", - "text": "a test text", - }, - { - "which": 2, + "_id": 2, "cat": "a", "subcat": "y", "i": 11, @@ -172,7 +158,7 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, CreateField "title": "bla", }, { - "which": 3, + "_id": 3, "cat": "b", "subcat": "y", "i": 31, @@ -189,12 +175,11 @@ def populate_index(index): TEST_DOCUMENTS, fields={ "text": CreateField(elastic_type="text"), - "title": CreateField(elastic_type="keyword"), + "title": CreateField(elastic_type="text"), "date": CreateField(elastic_type="date"), "cat": CreateField(elastic_type="keyword"), "subcat": CreateField(elastic_type="keyword"), "i": CreateField(elastic_type="long"), - "which": CreateField(elastic_type="short"), }, ) return TEST_DOCUMENTS diff --git a/tests/test_api_query.py b/tests/test_api_query.py index 5f54f20..11c7b69 100644 --- a/tests/test_api_query.py +++ b/tests/test_api_query.py @@ -10,7 +10,7 @@ def q(**body): return post_json(client, f"/index/{index_docs}/query", user=user, expected=200, json=body)["results"] def qi(**query_string): - return {doc["which"] for doc in q(**query_string)} + return {int(doc["_id"]) for doc in q(**query_string)} # Query strings assert qi(queries="text") == {0, 1} @@ -28,10 +28,10 @@ def qi(**query_string): assert qi(filters={"cat": {"values": ["a"]}}) == {0, 1, 2} # Can we request specific fields? 
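The tag assertions further below call query_documents directly with the new FieldSpec model instead of bare field names; a minimal sketch of that call shape (the index name is hypothetical):

    from amcat4.models import FieldSpec
    from amcat4.query import query_documents

    # Request only the 'tag' field; documents without a tag are skipped,
    # mirroring the tags() helper in the tests below.
    res = query_documents("my-index", fields=[FieldSpec(name="tag")])
    tags = {d["_id"]: d["tag"] for d in res.data if d.get("tag")}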
- all_fields = {"_id", "cat", "subcat", "i", "date", "text", "which"} + all_fields = {"_id", "cat", "subcat", "i", "date", "text", "title"} assert set(q()[0].keys()) == all_fields assert set(q(fields=["cat"])[0].keys()) == {"_id", "cat"} - assert set(q(fields=["date", "title"])[0].keys()) == {"_id", "date"} + assert set(q(fields=["date", "title"])[0].keys()) == {"_id", "date", "title"} def test_aggregate(client, index_docs, user): @@ -152,8 +152,8 @@ def test_aggregate_datemappings(client, index_docs, user): def test_query_tags(client, index_docs, user): def tags(): return { - doc["which"]: doc["tag"] - for doc in query_documents(index_docs, fields=[FieldSpec(name="tag"), FieldSpec(name="which")]).data + doc["_id"]: doc["tag"] + for doc in query_documents(index_docs, fields=[FieldSpec(name="tag")]).data if doc.get("tag") } @@ -166,7 +166,7 @@ def tags(): json=dict(action="add", field="tag", tag="x", filters={"cat": "a"}), ) refresh_index(index_docs) - assert tags() == {0: ["x"], 1: ["x"], 2: ["x"]} + assert tags() == {"0": ["x"], "1": ["x"], "2": ["x"]} post_json( client, f"/index/{index_docs}/tags_update", @@ -175,13 +175,13 @@ def tags(): json=dict(action="remove", field="tag", tag="x", queries=["text"]), ) refresh_index(index_docs) - assert tags() == {2: ["x"]} + assert tags() == {"2": ["x"]} post_json( client, f"/index/{index_docs}/tags_update", user=user, expected=204, - json=dict(action="add", field="tag", tag="y", filters={"which": {"gte": 1, "lte": 2}}), + json=dict(action="add", field="tag", tag="y", ids=["1", "2"]), ) refresh_index(index_docs) - assert tags() == {1: ["y"], 2: ["x", "y"]} + assert tags() == {"1": ["y"], "2": ["x", "y"]} diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 498ca4d..4b2a6ba 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -7,7 +7,8 @@ update_document, update_tag_by_query, ) -from amcat4.fields import update_fields, get_fields, field_values +from amcat4.fields import create_fields, update_fields, get_fields, field_values +from amcat4.models import CreateField, FieldSpec from amcat4.query import query_documents from tests.conftest import upload @@ -21,7 +22,7 @@ def test_upload_retrieve_document(index): _id="test", term_tfidf=[{"term": "test", "value": 0.2}, {"term": "value", "value": 0.3}], ) - upload_documents(index, [a]) + upload_documents(index, [a], fields={"text": "text", "title": "text", "date": "date", "term_tfidf": "nested"}) d = get_document(index, "test") assert d["title"] == a["title"] assert d["term_tfidf"] == a["term_tfidf"] @@ -30,11 +31,11 @@ def test_upload_retrieve_document(index): def test_data_coerced(index): """Are field values coerced to the correct field type""" - update_fields(index, {"i": "long"}) - a = dict(_id="DoccyMcDocface", text="text", title="test-numeric", date="2022-12-13", i="1") + create_fields(index, {"i": "long", "x": "double", "title": "text", "date": "date", "text": "text"}) + a = dict(_id="DoccyMcDocface", text="text", title="test-numeric", date="2022-12-13", i="1", x="1.1") upload_documents(index, [a]) d = get_document(index, "DoccyMcDocface") - assert isinstance(d["i"], float) + assert isinstance(d["i"], int) a = dict(text="text", title=1, date="2022-12-13") upload_documents(index, [a]) d = get_document(index, "DoccyMcDocface") @@ -43,9 +44,19 @@ def test_data_coerced(index): def test_fields(index): """Can we get the fields from an index""" + create_fields(index, {"title": "text", "date": "date", "text": "text", "url": "keyword"}) fields = get_fields(index) assert 
set(fields.keys()) == {"title", "date", "text", "url"} - assert fields["date"]["type"] == "date" + assert fields["title"].type == "text" + assert fields["date"].type == "date" + + # default settings + assert fields["date"].identifier == False + assert fields["date"].client_settings is not None + + # default settings depend on the type + assert fields["date"].metareader.access == "read" + assert fields["text"].metareader.access == "none" def test_values(index): @@ -68,7 +79,7 @@ def q(*ids): def tags(): return { doc["_id"]: doc["tag"] - for doc in query_documents(index_docs, fields=["tag"]).data + for doc in query_documents(index_docs, fields=[FieldSpec(name="tag")]).data if "tag" in doc and doc["tag"] is not None } @@ -89,9 +100,26 @@ def tags(): def test_deduplication(index): doc = {"title": "titel", "text": "text", "date": datetime(2020, 1, 1)} - upload_documents(index, [doc]) + upload_documents(index, [doc], fields={"title": "text", "text": "text", "date": "date"}) refresh_index(index) assert query_documents(index).total_count == 1 upload_documents(index, [doc]) refresh_index(index) assert query_documents(index).total_count == 1 + + +def test_identifier_deduplication(index): + doc = {"url": "http://", "text": "text"} + upload_documents(index, [doc], fields={"url": CreateField(elastic_type="wildcard", identifier=True), "text": "text"}) + refresh_index(index) + assert query_documents(index).total_count == 1 + + doc2 = {"url": "http://", "text": "text2"} + upload_documents(index, [doc2]) + refresh_index(index) + assert query_documents(index).total_count == 1 + + doc3 = {"url": "http://2", "text": "text"} + upload_documents(index, [doc3]) + refresh_index(index) + assert query_documents(index).total_count == 2 diff --git a/tests/test_query.py b/tests/test_query.py index 30b43a1..b257348 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -65,7 +65,7 @@ def test_fields(index_docs): def test_highlight(index): words = "The error of regarding functional notions is not quite equivalent to" text = f"{words} a test document. {words} other text documents. {words} you!" 
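The identifier-based deduplication tested above hinges on deriving _id from the identifier fields, so re-uploading a document with the same identifiers collides instead of duplicating. A rough sketch of the idea (the real _get_hash in amcat4.index may serialize differently):

    import hashlib
    import json

    def sketch_hash_id(doc: dict, identifiers: list[str]) -> str:
        # Hash only the identifier fields if any are defined, else the whole doc.
        keys = identifiers if identifiers else sorted(doc)
        payload = json.dumps({k: doc[k] for k in keys if k in doc}, sort_keys=True)
        return hashlib.sha224(payload.encode()).hexdigest()

    # Same url -> same _id, so the second upload cannot create a duplicate:
    assert sketch_hash_id({"url": "http://", "text": "a"}, ["url"]) == \
           sketch_hash_id({"url": "http://", "text": "b"}, ["url"])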
- upload(index, [dict(title="Een test titel", text=text)]) + upload(index, [dict(title="Een test titel", text=text)], fields={"title": "text", "text": "text"}) res = query.query_documents( index, fields=[FieldSpec(name="title"), FieldSpec(name="text")], queries={"1": "te*"}, highlight=True ) @@ -89,7 +89,7 @@ def test_highlight(index): def test_query_multiple_index(index_docs, index): - upload(index, [{"text": "also a text", "i": -1}], fields={"i": UpdateField(type="long")}) + upload(index, [{"text": "also a text", "i": -1}], fields={"i": "long", "text": "text"}) docs = query.query_documents([index_docs, index]) assert docs is not None assert len(docs.data) == 5 From c48be686c89a67cf54050dd7afa26011a3656c58 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 27 Mar 2024 12:22:17 +0100 Subject: [PATCH 38/80] don't overwrite guest role to none if absent --- amcat4/index.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/amcat4/index.py b/amcat4/index.py index 53151e3..a8541a6 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -120,7 +120,7 @@ def _index_from_elastic(index): id=index["_id"], name=src.get("name", index["_id"]), description=src.get("description"), - guest_role=Role[guest_role], + guest_role=Role[guest_role] if guest_role in Role.__members__ else Role.NONE, archived=src.get("archived"), roles=_roles_from_elastic(src.get("roles", [])), summary_field=src.get("summary_field"), @@ -277,7 +277,7 @@ def modify_index( doc = dict( name=name, description=description, - guest_role=guest_role.name if guest_role is not None else "NONE", + guest_role=guest_role.name if guest_role is not None else None, summary_field=summary_field, archived=archived, ) From c3fdca6fec81979d21606f2686ac195d03bf726b Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 27 Mar 2024 18:34:23 +0100 Subject: [PATCH 39/80] field stuff --- amcat4/__main__.py | 1 - amcat4/aggregate.py | 2 - amcat4/api/index.py | 17 +++-- amcat4/config.py | 6 +- amcat4/fields.py | 120 +++++++++++++++++++++-------------- amcat4/index.py | 12 +--- amcat4/models.py | 17 ++--- tests/conftest.py | 18 +++--- tests/test_aggregate.py | 2 +- tests/test_api_documents.py | 6 +- tests/test_api_index.py | 9 ++- tests/test_api_pagination.py | 4 +- tests/test_api_query.py | 8 +-- tests/test_elastic.py | 2 +- tests/test_index.py | 20 +++--- 15 files changed, 127 insertions(+), 117 deletions(-) diff --git a/amcat4/__main__.py b/amcat4/__main__.py index 7c21c39..cd6eaa1 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -182,7 +182,6 @@ def list_users(_args): def config_amcat(args): settings = get_settings() - settings_fields = settings.model_fields # Not a useful entry in an actual env_file print(f"Reading/writing settings from {settings.env_file}") for fieldname, fieldinfo in settings.model_fields.items(): diff --git a/amcat4/aggregate.py b/amcat4/aggregate.py index 95900d1..782bdf0 100644 --- a/amcat4/aggregate.py +++ b/amcat4/aggregate.py @@ -4,8 +4,6 @@ import copy from datetime import datetime -from itertools import islice -import json from typing import Any, Mapping, Iterable, Union, Tuple, Sequence, List, Dict from amcat4.date_mappings import interval_mapping diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 2a09830..983a25f 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -1,7 +1,6 @@ """API Endpoints for document and index management.""" from http import HTTPStatus -from operator import is_ from typing import Annotated, Any, Literal import elasticsearch @@ -15,7 +14,7 @@ from 
amcat4.index import refresh_system_index, remove_role, set_role from amcat4.fields import field_values, field_stats -from amcat4.models import CreateField, ElasticType, Field, UpdateField +from amcat4.models import CreateField, ElasticType, UpdateField app_index = APIRouter(prefix="/index", tags=["index"]) @@ -149,7 +148,7 @@ def archive_index( is_archived = d.archived is not None if is_archived == archived: return - archived_date = datetime.now() if archived else None + archived_date = str(datetime.now()) if archived else None index.modify_index(ix, archived=archived_date) except index.IndexDoesNotExist: @@ -180,9 +179,11 @@ def upload_documents( ix: str, documents: Annotated[list[dict[str, Any]], Body(description="The documents to upload")], fields: Annotated[ - dict[str, str | CreateField] | None, + dict[str, ElasticType | CreateField] | None, Body( - description="If a field in documents does not yet exist, you can create it on the spot. If you only need to specify the type, and use the default settings, you can use the short form: {field: type}" + description="If a field in documents does not yet exist, you can create it on the spot. " + "If you only need to specify the type, and use the default settings, " + "you can use the short form: {field: type}" ), ] = None, operation: Annotated[ @@ -271,9 +272,11 @@ def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user) def create_fields( ix: str, fields: Annotated[ - dict[str, str | CreateField], + dict[str, ElasticType | CreateField], Body( - description="Either a dictionary that maps field names to field specifications ({field: {elastic_type: text, identifier:True}}), or a simplified version that only specifies the type ({field: type})" + description="Either a dictionary that maps field names to field specifications" + "({field: {type: text, identifier:True}}), " + "or a simplified version that only specifies the type ({field: type})" ), ], user: str = Depends(authenticated_user), diff --git a/amcat4/config.py b/amcat4/config.py index b560f74..ecb4bdc 100644 --- a/amcat4/config.py +++ b/amcat4/config.py @@ -10,7 +10,7 @@ import functools from enum import Enum from pathlib import Path -from typing import Annotated +from typing import Annotated, Any from class_doc import extract_docs_from_cls_obj from dotenv import load_dotenv from pydantic import model_validator, Field @@ -40,7 +40,7 @@ def validate(cls, value: str): return f"{value} is not a valid authorization option. Choose one of {{{options}}}" -# As far as I know, there is no elegant built-in way to set to __doc__ of an enum? +# Set the __doc__ attribute of each AuthOptions enum member using extract_docs_from_cls_obj for field, doc in extract_docs_from_cls_obj(AuthOptions).items(): AuthOptions[field].__doc__ = "\n".join(doc) @@ -113,7 +113,7 @@ class Settings(BaseSettings): admin_password: Annotated[str | None, Field()] = None @model_validator(mode="after") - def set_ssl(self) -> "Settings": + def set_ssl(self: Any) -> "Settings": if not self.elastic_host: self.elastic_host = ("https" if self.elastic_password else "http") + "://localhost:9200" if not self.elastic_verify_ssl: diff --git a/amcat4/fields.py b/amcat4/fields.py index e52c6be..c15bbb3 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -1,8 +1,10 @@ """ We have two types of fields: -- Elastic fields are the fields used under the hood by elastic. (https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-types.html +- Elastic fields are the fields used under the hood by elastic. 
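To make the two layers concrete, a small usage sketch against the helpers this patch reworks (attribute names follow the post-patch models, so treat it as an approximation):

    from amcat4.fields import TYPEMAP_ES_TO_AMCAT, get_default_field

    # An elastic type resolves to a coarser amcat type group ...
    assert TYPEMAP_ES_TO_AMCAT["text"] == "text"

    # ... which in turn determines the default metareader access:
    assert get_default_field("date").metareader.access == "read"
    assert get_default_field("text").metareader.access == "none"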
+ (https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-types.html These are stored in the Mapping of an index -- Amcat fields (Field) are the fields are seen by the amcat user. They use a simplified type, and contain additional information such as metareader access +- Amcat fields (Field) are the fields are seen by the amcat user. They use a simplified type, and contain additional + information such as metareader access These are stored in the system index. We need to make sure that: @@ -12,7 +14,8 @@ system index """ -from typing import Any, Iterator, get_args, cast +from hmac import new +from typing import Any, Iterator, Mapping, get_args, cast from elasticsearch import NotFoundError @@ -20,12 +23,12 @@ # from amcat4.api.common import py2dict from amcat4.config import get_settings from amcat4.elastic import es -from amcat4.models import AmcatType, CreateField, ElasticType, Field, UpdateField, updateField, FieldMetareaderAccess +from amcat4.models import TypeGroup, CreateField, ElasticType, Field, UpdateField, FieldMetareaderAccess # given an elastic field type, infer # (this is relevant if we are importing an index that does not yet have ) -TYPEMAP_ES_TO_AMCAT: dict[ElasticType, AmcatType] = { +TYPEMAP_ES_TO_AMCAT: dict[ElasticType, TypeGroup] = { # TEXT fields "text": "text", "annotated_text": "text", @@ -62,27 +65,27 @@ } -def get_default_metareader(amcat_type: AmcatType): - if amcat_type in ["boolean", "number", "date"]: +def get_default_metareader(type_group: TypeGroup): + if type_group in ["boolean", "number", "date"]: return FieldMetareaderAccess(access="read") return FieldMetareaderAccess(access="none") def get_default_field(elastic_type: ElasticType): - amcat_type = TYPEMAP_ES_TO_AMCAT.get(elastic_type) - if amcat_type is None: + type_group = TYPEMAP_ES_TO_AMCAT.get(elastic_type) + if type_group is None: raise ValueError(f"Invalid elastic type: {elastic_type}") - return Field(type=amcat_type, elastic_type=elastic_type, metareader=get_default_metareader(amcat_type)) + return Field(type_group=type_group, type=elastic_type, metareader=get_default_metareader(type_group)) -def _standardize_createfields(fields: dict[str, str | CreateField]) -> dict[str, CreateField]: +def _standardize_createfields(fields: Mapping[str, ElasticType | CreateField]) -> dict[str, CreateField]: sfields = {} for k, v in fields.items(): if isinstance(v, str): assert v in get_args(ElasticType), f"Unknown elastic type {v}" - sfields[k] = CreateField(elastic_type=cast(ElasticType, v)) + sfields[k] = CreateField(type=cast(ElasticType, v)) else: sfields[k] = v return sfields @@ -113,38 +116,57 @@ def coerce_type(value: Any, elastic_type: ElasticType): return value -def create_fields(index: str, fields: dict[str, str | CreateField]): +def create_fields(index: str, fields: Mapping[str, ElasticType | CreateField]): mapping: dict[str, Any] = {} - current_fields = {k: v for k, v in _get_index_fields(index)} + current_fields = get_fields(index) + + new_fields: dict[str, Field] = {} - new_fields: dict[str, CreateField] = {} sfields = _standardize_createfields(fields) for field, settings in sfields.items(): - if field not in current_fields: - new_fields[field] = settings - - if TYPEMAP_ES_TO_AMCAT.get(settings.elastic_type) is None: - raise ValueError(f"Field type {settings.elastic_type} not supported by AmCAT") - - current_type = current_fields.get(field) - if current_type is not None: - if current_type != settings.elastic_type: + if TYPEMAP_ES_TO_AMCAT.get(settings.type) is None: + raise 
ValueError(f"Field type {settings.type} not supported by AmCAT") + + current = current_fields.get(field) + if current is not None: + # fields can already exist. For example, a scraper might include the field types in every + # upload request. If a field already exists, we'll ignore it, but we will throw an error + # if static settings (field type, identifier) do not match + if current.type != settings.type: raise ValueError( - f"Field '{field}' already exists with type '{current_type}'. Cannot change type to '{settings.elastic_type}'" + f"Field '{field}' already exists with type '{current.type}'. " f"Cannot change type to '{settings.type}'" ) - continue - - mapping[field] = {"type": settings.elastic_type} - - if settings.elastic_type in ["date"]: - mapping[field]["format"] = "strict_date_optional_time" - - es().indices.put_mapping(index=index, properties=mapping) - update_fields(index, new_fields) + if current.identifier and not settings.identifier: + raise ValueError(f"Field '{field}' is an identifier, cannot change to non-identifier") + if not current.identifier and settings.identifier: + raise ValueError(f"Field '{field}' is not an identifier, cannot change to identifier") + new_fields[field] = current + else: + # if field does not exist, we add it to both the mapping and the system index + mapping[field] = {"type": settings.type} + if settings.type in ["date"]: + mapping[field]["format"] = "strict_date_optional_time" + + new_fields[field] = Field( + type=settings.type, + type_group=TYPEMAP_ES_TO_AMCAT[settings.type], + identifier=settings.identifier, + metareader=settings.metareader or get_default_metareader(TYPEMAP_ES_TO_AMCAT[settings.type]), + client_settings=settings.client_settings or {}, + ) + print(mapping) + if len(mapping) > 0: + es().indices.put_mapping(index=index, properties=mapping) + es().update( + index=get_settings().system_index, + id=index, + doc=dict(fields=_fields_to_elastic(new_fields)), + ) def _fields_to_elastic(fields: dict[str, Field]) -> list[dict]: + # some additional validation return [{"field": field, "settings": settings.model_dump()} for field, settings in fields.items()] @@ -154,30 +176,31 @@ def _fields_from_elastic( return {fs["field"]: Field.model_validate(fs["settings"]) for fs in fields} -def update_fields(index: str, new_fields: dict[str, UpdateField] | dict[str, Field] | dict[str, CreateField]): +def update_fields(index: str, fields: dict[str, UpdateField]): """ Set the fields settings for this index. Only updates fields that - already exist. type and elastic_type cannot be changed. + already exist. 
Only keys in UpdateField can be updated (not type or client_settings) """ - system_index = get_settings().system_index - fields = get_fields(index) + current_fields = get_fields(index) - for field, new_settings in new_fields.items(): - current = fields.get(field) + for field, new_settings in fields.items(): + current = current_fields.get(field) if current is None: raise ValueError(f"Field {field} does not exist") - if current.type != "text": - if new_settings.metareader and new_settings.metareader.access == "snippet": + if new_settings.metareader is not None: + if current.type != "text" and new_settings.metareader.access == "snippet": raise ValueError(f"Field {field} is not of type text, cannot set metareader access to snippet") + current_fields[field].metareader = new_settings.metareader - fields[field] = updateField(field=current, update=new_settings) + if new_settings.client_settings is not None: + current_fields[field].client_settings = new_settings.client_settings es().update( - index=system_index, + index=get_settings().system_index, id=index, - doc=dict(fields=_fields_to_elastic(fields)), + doc=dict(fields=_fields_to_elastic(current_fields)), ) @@ -210,9 +233,9 @@ def get_fields(index: str) -> dict[str, Field]: update_system_index = False for field, elastic_type in _get_index_fields(index): - amcat_type = TYPEMAP_ES_TO_AMCAT.get(elastic_type) + type_group = TYPEMAP_ES_TO_AMCAT.get(elastic_type) - if amcat_type is None: + if type_group is None: # skip over unsupported elastic fields. # (TODO: also return warning to client?) continue @@ -237,7 +260,8 @@ def field_values(index: str, field: str, size: int) -> list[str]: """ Get the values for a given field (e.g. to populate list of filter values on keyword field) Results are sorted descending by document frequency - see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html#search-aggregations-bucket-terms-aggregation-order + see: https://www.elastic.co/guide/en/elasticsearch/reference/7.4/search-aggregations-bucket-terms-aggregation.html + #search-aggregations-bucket-terms-aggregation-order :param index: The index :param field: The field name diff --git a/amcat4/index.py b/amcat4/index.py index a8541a6..3877051 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -33,22 +33,19 @@ """ import collections -from datetime import datetime from enum import IntEnum -from typing import Any, Iterable, Optional, Literal, cast, get_args +from typing import Any, Iterable, Mapping, Optional, Literal import hashlib import json import elasticsearch.helpers from elasticsearch import NotFoundError -from httpx import get # from amcat4.api.common import py2dict from amcat4.config import AuthOptions, get_settings from amcat4.elastic import es from amcat4.fields import ( - _standardize_createfields, coerce_type, create_fields, get_fields, @@ -272,13 +269,11 @@ def modify_index( description: Optional[str] = None, guest_role: Optional[Role] = None, archived: Optional[str] = None, - summary_field=None, ): doc = dict( name=name, description=description, guest_role=guest_role.name if guest_role is not None else None, - summary_field=summary_field, archived=archived, ) @@ -424,7 +419,7 @@ def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: def upload_documents( - index: str, documents: list[dict[str, Any]], fields: dict[str, str | CreateField] | None = None, op_type="index" + index: str, documents: list[dict[str, Any]], fields: Mapping[str, ElasticType | CreateField] | None = None, 
op_type="index" ): """ Upload documents to this index @@ -433,7 +428,6 @@ def upload_documents( :param documents: A sequence of article dictionaries :param fields: A mapping of fieldname:UpdateField for field types """ - if fields: create_fields(index, fields) @@ -446,7 +440,7 @@ def es_actions(index, documents, op_type): continue if key not in field_settings: raise ValueError(f"The type for field '{key}' is not yet specified") - document[key] = coerce_type(document[key], field_settings[key].elastic_type) + document[key] = coerce_type(document[key], field_settings[key].type) if "_id" not in document: document["_id"] = _get_hash(document, field_settings) yield {"_op_type": op_type, "_index": index, **document} diff --git a/amcat4/models.py b/amcat4/models.py index 85d2e88..fd60eb0 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -1,9 +1,9 @@ import pydantic from pydantic import BaseModel -from typing import Annotated, Any, Literal, NewType +from typing import Annotated, Any, Literal -AmcatType = Literal["text", "date", "boolean", "keyword", "number", "object", "vector", "geo"] +TypeGroup = Literal["text", "date", "boolean", "keyword", "number", "object", "vector", "geo"] ElasticType = Literal[ "text", "annotated_text", @@ -54,8 +54,8 @@ class Field(BaseModel): """Settings for a field. Some settings, such as metareader, have a strict model because they are used server side. Others, such as client_settings, are free-form and can be used by the client to store settings.""" - type: AmcatType - elastic_type: ElasticType + type: ElasticType + type_group: TypeGroup identifier: bool = False metareader: FieldMetareaderAccess = FieldMetareaderAccess() client_settings: dict[str, Any] = {} @@ -64,7 +64,7 @@ class Field(BaseModel): class CreateField(BaseModel): """Model for creating a field""" - elastic_type: ElasticType + type: ElasticType identifier: bool = False metareader: FieldMetareaderAccess | None = None client_settings: dict[str, Any] | None = None @@ -77,13 +77,6 @@ class UpdateField(BaseModel): client_settings: dict[str, Any] | None = None -def updateField(field: Field, update: UpdateField | Field | CreateField): - - for key in update.model_fields_set: - setattr(field, key, getattr(update, key)) - return field - - FilterValue = str | int diff --git a/tests/conftest.py b/tests/conftest.py index 9aa5f84..f3beb29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ set_global_role, upload_documents, ) -from amcat4.models import CreateField +from amcat4.models import CreateField, ElasticType from tests.middlecat_keypair import PUBLIC_KEY UNITS = [ @@ -136,7 +136,7 @@ def guest_index(): delete_index(index, ignore_missing=True) -def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, str | CreateField] | None = None): +def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, ElasticType | CreateField] | None = None): """ Upload these docs to the index, giving them an incremental id, and flush """ @@ -174,12 +174,12 @@ def populate_index(index): index, TEST_DOCUMENTS, fields={ - "text": CreateField(elastic_type="text"), - "title": CreateField(elastic_type="text"), - "date": CreateField(elastic_type="date"), - "cat": CreateField(elastic_type="keyword"), - "subcat": CreateField(elastic_type="keyword"), - "i": CreateField(elastic_type="long"), + "text": CreateField(type="text"), + "title": CreateField(type="text"), + "date": CreateField(type="date"), + "cat": CreateField(type="keyword"), + "subcat": CreateField(type="keyword"), + "i": 
CreateField(type="long"), }, ) return TEST_DOCUMENTS @@ -203,7 +203,7 @@ def index_many(): upload( index, [dict(id=i, pagenr=abs(10 - i), text=text) for (i, text) in enumerate(["odd", "even"] * 10)], - fields={"id": CreateField(elastic_type="long"), "pagenr": CreateField(elastic_type="long")}, + fields={"id": CreateField(type="long"), "pagenr": CreateField(type="long")}, ) yield index delete_index(index, ignore_missing=True) diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index c187614..c1e41cc 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -115,7 +115,7 @@ def test_aggregate_datefunctions(index: str): "2018-03-07T23:59:00", # wednesday evening ] ] - upload(index, docs, fields=dict(date=CreateField(elastic_type="date"))) + upload(index, docs, fields=dict(date=CreateField(type="date"))) assert q(Axis("date", interval="day")) == { date(2018, 1, 1): 2, date(2018, 1, 11): 1, diff --git a/tests/test_api_documents.py b/tests/test_api_documents.py index 83b0a44..612a219 100644 --- a/tests/test_api_documents.py +++ b/tests/test_api_documents.py @@ -29,9 +29,9 @@ def test_documents(client, index, user): json={ "documents": [{"_id": "id", "title": "a title", "text": "text", "date": "2020-01-01"}], "fields": { - "title": {"elastic_type": "text"}, - "text": {"elastic_type": "text"}, - "date": {"elastic_type": "date"}, + "title": {"type": "text"}, + "text": {"type": "text"}, + "date": {"type": "date"}, }, }, ) diff --git a/tests/test_api_index.py b/tests/test_api_index.py index 0105b04..ed90496 100644 --- a/tests/test_api_index.py +++ b/tests/test_api_index.py @@ -4,7 +4,6 @@ from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role from amcat4.fields import update_fields -from amcat4.models import CreateField, Field, UpdateField from tests.tools import build_headers, post_json, get_json, check, refresh @@ -81,10 +80,10 @@ def test_fields_upload(client: TestClient, user: str, index: str): for i, x in enumerate(["a", "a", "b"]) ], "fields": { - "title": dict(elastic_type="text"), - "text": dict(elastic_type="text"), - "date": dict(elastic_type="date"), - "x": dict(elastic_type="keyword"), + "title": "text", + "text": "text", + "date": "date", + "x": "keyword", }, } diff --git a/tests/test_api_pagination.py b/tests/test_api_pagination.py index 5c1ada1..c6452db 100644 --- a/tests/test_api_pagination.py +++ b/tests/test_api_pagination.py @@ -11,7 +11,7 @@ def test_pagination(client, index, user): # TODO. Tests are not independent. test_pagination fails if run directly after other tests. 
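For reference, the scroll workflow exercised in the tests below, sketched client-side (base URL hypothetical, auth omitted; the 404 on the final request signals an exhausted scroll):

    import httpx

    def scroll_all(base: str, ix: str) -> list[dict]:
        docs = []
        body: dict = {"scroll": "5m", "per_page": 30, "fields": ["i"]}
        while True:
            r = httpx.post(f"{base}/index/{ix}/query", json=body)
            if r.status_code == 404:  # scrolled past the last batch
                return docs
            data = r.json()
            docs.extend(data["results"])
            # Subsequent requests only need the scroll id from the meta block.
            body = {"scroll_id": data["meta"]["scroll_id"]}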
# Probably delete_index doesn't fully delete - upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(elastic_type="long")}) + upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(type="long")}) url = f"/index/{index}/query" r = post_json(client, url, user=user, json={"sort": "i", "per_page": 20, "fields": ["i"]}, expected=200) @@ -28,7 +28,7 @@ def test_pagination(client, index, user): def test_scroll(client, index, user): set_role(index, user, Role.READER) - upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(elastic_type="long")}) + upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(type="long")}) url = f"/index/{index}/query" r = post_json( client, diff --git a/tests/test_api_query.py b/tests/test_api_query.py index 11c7b69..7c2ba2f 100644 --- a/tests/test_api_query.py +++ b/tests/test_api_query.py @@ -2,7 +2,7 @@ from amcat4.models import CreateField, FieldSpec from amcat4.query import query_documents from tests.conftest import upload -from tests.tools import get_json, post_json, dictset +from tests.tools import post_json, dictset def test_query_post(client, index_docs, user): @@ -93,9 +93,9 @@ def test_multiple_index(client, index_docs, index, user): index, [{"text": "also a text", "i": -1, "cat": "c"}], fields={ - "text": CreateField(elastic_type="text"), - "cat": CreateField(elastic_type="keyword"), - "i": CreateField(elastic_type="long"), + "text": CreateField(type="text"), + "cat": CreateField(type="keyword"), + "i": CreateField(type="long"), }, ) indices = f"{index},{index_docs}" diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 4b2a6ba..e008636 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -110,7 +110,7 @@ def test_deduplication(index): def test_identifier_deduplication(index): doc = {"url": "http://", "text": "text"} - upload_documents(index, [doc], fields={"url": CreateField(elastic_type="wildcard", identifier=True), "text": "text"}) + upload_documents(index, [doc], fields={"url": CreateField(type="wildcard", identifier=True), "text": "text"}) refresh_index(index) assert query_documents(index).total_count == 1 diff --git a/tests/test_index.py b/tests/test_index.py index c5c3685..543d02c 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -160,13 +160,13 @@ def test_name_description(index): assert indices[index].name == "test" -def test_summary_field(index): - with pytest.raises(Exception): - modify_index(index, summary_field="doesnotexist") - with pytest.raises(Exception): - modify_index(index, summary_field="title") - update_fields(index, {"party": Field(type="keyword", elastic_type="keyword")}) - modify_index(index, summary_field="party") - assert get_index(index).summary_field == "party" - modify_index(index, summary_field="date") - assert get_index(index).summary_field == "date" +# def test_summary_field(index): +# with pytest.raises(Exception): +# modify_index(index, summary_field="doesnotexist") +# with pytest.raises(Exception): +# modify_index(index, summary_field="title") +# update_fields(index, {"party": Field(type="keyword", type="keyword")}) +# modify_index(index, summary_field="party") +# assert get_index(index).summary_field == "party" +# modify_index(index, summary_field="date") +# assert get_index(index).summary_field == "date" From 1b0148bc8004d94a9f21d3e09d062cd102acc366 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Thu, 28 Mar 2024 17:13:13 +0100 Subject: [PATCH 40/80] upload --- amcat4/api/index.py | 18 
++++++++++----
 amcat4/fields.py    | 37 ++++++++++++++++++-------------------
 amcat4/index.py     | 21 +++++++++++++++++----
 3 files changed, 49 insertions(+), 27 deletions(-)

diff --git a/amcat4/api/index.py b/amcat4/api/index.py
index 983a25f..9b2e200 100644
--- a/amcat4/api/index.py
+++ b/amcat4/api/index.py
@@ -1,6 +1,7 @@
 """API Endpoints for document and index management."""

 from http import HTTPStatus
+from re import U
 from typing import Annotated, Any, Literal

 import elasticsearch
@@ -188,15 +189,24 @@ def upload_documents(
     ] = None,
     operation: Annotated[
         Literal["index", "update", "create"],
-        Body(description="The operation to perform. (the default, index, is like upsert)"),
-    ] = "index",
+        Body(
+            description="The operation to perform. Default is create, which ignores any documents that already exist. "
+            "The 'index' operation basically means 'create or replace', and will completely overwrite any existing documents. "
+            "The 'update' operation behaves as an upsert (create or update). This is less destructive than 'index', "
+            "because it will only update the fields that are specified in the document. Since index and update are destructive, "
+            "they require admin rights."
+        ),
+    ] = "create",
     user: str = Depends(authenticated_user),
 ):
     """
     Upload documents to this server. Returns a list of ids for the uploaded documents
     """
-    check_role(user, index.Role.WRITER, ix)
-    return index.upload_documents(ix, documents, fields, operation)
+    if operation == "create":
+        check_role(user, index.Role.WRITER, ix)
+    else:
+        check_role(user, index.Role.ADMIN, ix)
+    return index.upload_documents(ix, documents, fields, operation, return_ids=False)
diff --git a/amcat4/fields.py b/amcat4/fields.py
index c15bbb3..f571d8f 100644
--- a/amcat4/fields.py
+++ b/amcat4/fields.py
@@ -120,8 +120,6 @@ def create_fields(index: str, fields: Mapping[str, ElasticType | CreateField]):
     mapping: dict[str, Any] = {}
     current_fields = get_fields(index)

-    new_fields: dict[str, Field] = {}
-
     sfields = _standardize_createfields(fields)

     for field, settings in sfields.items():
@@ -132,7 +130,7 @@ def create_fields(index: str, fields: Mapping[str, ElasticType | CreateField]):
         if current is not None:
             # fields can already exist. For example, a scraper might include the field types in every
             # upload request. If a field already exists, we'll ignore it, but we will throw an error
-            # if static settings (field type, identifier) do not match
+            # if static settings (field type, identifier) do not match.
             if current.type != settings.type:
                 raise ValueError(
                     f"Field '{field}' already exists with type '{current.type}'. 
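The create/update semantics described above map onto elasticsearch bulk actions; a sketch of the action shapes, mirroring the es_actions generator added to amcat4/index.py further down in this patch:

    # 'create' sends the document directly and counts existing ids as
    # failures, while 'update' wraps it in doc/doc_as_upsert so fields
    # missing from the upload are left untouched.
    def to_action(index: str, doc: dict, op_type: str) -> dict:
        if op_type == "update":
            doc = dict(doc)
            _id = doc.pop("_id")
            return {"_op_type": "update", "_index": index, "_id": _id,
                    "doc": doc, "doc_as_upsert": True}
        return {"_op_type": op_type, "_index": index, **doc}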
" f"Cannot change type to '{settings.type}'" @@ -141,27 +139,28 @@ def create_fields(index: str, fields: Mapping[str, ElasticType | CreateField]): raise ValueError(f"Field '{field}' is an identifier, cannot change to non-identifier") if not current.identifier and settings.identifier: raise ValueError(f"Field '{field}' is not an identifier, cannot change to identifier") - new_fields[field] = current - else: - # if field does not exist, we add it to both the mapping and the system index - mapping[field] = {"type": settings.type} - if settings.type in ["date"]: - mapping[field]["format"] = "strict_date_optional_time" - - new_fields[field] = Field( - type=settings.type, - type_group=TYPEMAP_ES_TO_AMCAT[settings.type], - identifier=settings.identifier, - metareader=settings.metareader or get_default_metareader(TYPEMAP_ES_TO_AMCAT[settings.type]), - client_settings=settings.client_settings or {}, - ) - print(mapping) + + continue + + # if field does not exist, we add it to both the mapping and the system index + mapping[field] = {"type": settings.type} + if settings.type in ["date"]: + mapping[field]["format"] = "strict_date_optional_time" + + current_fields[field] = Field( + type=settings.type, + type_group=TYPEMAP_ES_TO_AMCAT[settings.type], + identifier=settings.identifier, + metareader=settings.metareader or get_default_metareader(TYPEMAP_ES_TO_AMCAT[settings.type]), + client_settings=settings.client_settings or {}, + ) + if len(mapping) > 0: es().indices.put_mapping(index=index, properties=mapping) es().update( index=get_settings().system_index, id=index, - doc=dict(fields=_fields_to_elastic(new_fields)), + doc=dict(fields=_fields_to_elastic(current_fields)), ) diff --git a/amcat4/index.py b/amcat4/index.py index 3877051..d535373 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -419,7 +419,11 @@ def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: def upload_documents( - index: str, documents: list[dict[str, Any]], fields: Mapping[str, ElasticType | CreateField] | None = None, op_type="index" + index: str, + documents: list[dict[str, Any]], + fields: Mapping[str, ElasticType | CreateField] | None = None, + op_type="index", + return_ids=True, ): """ Upload documents to this index @@ -443,12 +447,21 @@ def es_actions(index, documents, op_type): document[key] = coerce_type(document[key], field_settings[key].type) if "_id" not in document: document["_id"] = _get_hash(document, field_settings) - yield {"_op_type": op_type, "_index": index, **document} + + # https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html + if op_type == "update": + id = document.pop("_id") + yield {"_op_type": op_type, "_index": index, "_id": id, "doc": document, "doc_as_upsert": True} + else: + yield {"_op_type": op_type, "_index": index, **document} actions = list(es_actions(index, documents, op_type)) - ids = [doc["_id"] for doc in actions] successes, failures = elasticsearch.helpers.bulk(es(), actions, stats_only=True, raise_on_error=False) - return dict(ids=ids, successes=successes, failures=failures) + + if return_ids: + ids = [doc["_id"] for doc in actions] + return dict(ids=ids, successes=successes, failures=failures) + return dict(successes=successes, failures=failures) def get_document(index: str, doc_id: str, **kargs) -> dict: From 9d07183dbc85aa9d3bd2e1b48920d54626f14506 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Sun, 31 Mar 2024 18:09:13 +0200 Subject: [PATCH 41/80] bringing back tags --- amcat4/api/index.py | 10 +++++----- amcat4/api/query.py | 8 
++++++++ amcat4/fields.py | 24 ++++++++++++++++++++++++ amcat4/index.py | 4 +++- amcat4/models.py | 1 + 5 files changed, 41 insertions(+), 6 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 9b2e200..d1cc9ec 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -188,13 +188,13 @@ def upload_documents( ), ] = None, operation: Annotated[ - Literal["index", "update", "create"], + Literal["update", "create"], Body( description="The operation to perform. Default is create, which ignores any documents that already exist. " - "The 'index' operation basically means 'create or replace', and will completely overwrite any existing documents. " - "The 'update' operation behaves as an upsert (create or update). This is less destructive than 'index', " - "because it will only update the fields that are specified in the document. Since index and update are destructive, " - "they require admin rights." + "The 'update' operation behaves as an upsert (create or update). If an identical document (or document with " + "identical identifiers) already exists, the uploaded fields will be created or overwritten. If there are fields " + "in the original document that are not in the uploaded document, they will NOT be removed. Since update is destructive, " + "it requires admin rights." ), ] = "create", user: str = Depends(authenticated_user), diff --git a/amcat4/api/query.py b/amcat4/api/query.py index c806c1f..3de1827 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -10,6 +10,7 @@ from amcat4 import query, aggregate from amcat4.aggregate import Axis, Aggregation from amcat4.api.auth import authenticated_user, check_fields_access +from amcat4.fields import create_fields from amcat4.index import Role, get_role, get_fields from amcat4.models import FieldSpec, FilterSpec, FilterValue, SortSpec from amcat4.query import update_tag_query @@ -362,6 +363,13 @@ def query_update_tags( """ indices = index.split(",") + for i in indices: + if get_role(i, user) < Role.WRITER: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"User {user} does not have permission to update tags on index {i}", + ) + if isinstance(ids, (str, int)): ids = [ids] update_tag_query(indices, action, field, tag, _standardize_queries(queries), _standardize_filters(filters), ids) diff --git a/amcat4/fields.py b/amcat4/fields.py index f571d8f..57f6e55 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -255,6 +255,30 @@ def get_fields(index: str) -> dict[str, Field]: return fields +def create_or_verify_tag_field(index: str | list[str], field: str): + """Create a special type of field that can be used to tag documents. 
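+ A tag field is stored as an elastic keyword field, and its values are modified through the tags_update endpoint rather than by re-uploading documents.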
+ Since adding/removing tags supports multiple indices, we first check whether the field name is valid for all indices""" + indices = [index] if isinstance(index, str) else index + add_to_indices = [] + for i in indices: + current_fields = get_fields(i) + if field in current_fields: + if current_fields[field].tag is False: + raise ValueError(f"Field '{field}' already exists in index '{i}' and is not a tag field") + + else: + add_to_indices.append(i) + + for i in add_to_indices: + current_fields[field] = Field(type="keyword", type_group=TYPEMAP_ES_TO_AMCAT["keyword"], tag=True) + es().indices.put_mapping(index=index, properties={field: {"type": "keyword"}}) + es().update( + index=get_settings().system_index, + id=i, + doc=dict(fields=_fields_to_elastic(current_fields)), + ) + + def field_values(index: str, field: str, size: int) -> list[str]: """ Get the values for a given field (e.g. to populate list of filter values on keyword field) diff --git a/amcat4/index.py b/amcat4/index.py index d535373..e6c45ed 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -48,6 +48,7 @@ from amcat4.fields import ( coerce_type, create_fields, + create_or_verify_tag_field, get_fields, ) from amcat4.models import CreateField, ElasticType, Field @@ -499,7 +500,7 @@ def delete_document(index: str, doc_id: str): def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None = None): script_dict = dict(source=script, lang="painless", params=params or {}) - es().update_by_query(index=index, script=script_dict, **query) + test = es().update_by_query(index=index, script=script_dict, **query) TAG_SCRIPTS = dict( @@ -521,6 +522,7 @@ def update_by_query(index: str | list[str], script: str, query: dict, params: di def update_tag_by_query(index: str | list[str], action: Literal["add", "remove"], query: dict, field: str, tag: str): + create_or_verify_tag_field(index, field) script = TAG_SCRIPTS[action] params = dict(field=field, tag=tag) update_by_query(index, script, query, params) diff --git a/amcat4/models.py b/amcat4/models.py index fd60eb0..f9de2f8 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -57,6 +57,7 @@ class Field(BaseModel): type: ElasticType type_group: TypeGroup identifier: bool = False + tag: bool = False metareader: FieldMetareaderAccess = FieldMetareaderAccess() client_settings: dict[str, Any] = {} From 33bdd38407091b6c613ade7e434845a9c9b6b130 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Mon, 1 Apr 2024 18:55:25 +0200 Subject: [PATCH 42/80] changed field types to enable special amcat fields --- amcat4/api/index.py | 2 +- amcat4/fields.py | 113 ++++++++++++++++++++++++++++-------------- amcat4/index.py | 2 +- amcat4/models.py | 16 +++--- tests/test_elastic.py | 1 + 5 files changed, 88 insertions(+), 46 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index d1cc9ec..36daae9 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -285,7 +285,7 @@ def create_fields( dict[str, ElasticType | CreateField], Body( description="Either a dictionary that maps field names to field specifications" - "({field: {type: text, identifier:True}}), " + "({field: {type: 'text', identifier: True }}), " "or a simplified version that only specifies the type ({field: type})" ), ], diff --git a/amcat4/fields.py b/amcat4/fields.py index 57f6e55..be8c6c7 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -19,16 +19,20 @@ from elasticsearch import NotFoundError +from httpx import get # from amcat4.api.common import py2dict +from amcat4 import elastic from 
amcat4.config import get_settings from amcat4.elastic import es -from amcat4.models import TypeGroup, CreateField, ElasticType, Field, UpdateField, FieldMetareaderAccess +from amcat4.models import FieldType, CreateField, ElasticType, Field, UpdateField, FieldMetareaderAccess -# given an elastic field type, infer -# (this is relevant if we are importing an index that does not yet have ) -TYPEMAP_ES_TO_AMCAT: dict[ElasticType, TypeGroup] = { +# given an elastic field type, Check if it is supported by AmCAT. +# this is not just the inverse of TYPEMAP_AMCAT_TO_ES because some AmCAT types map to multiple elastic +# types (e.g., tag and keyword, image_url and wildcard) +# (this is relevant if we are importing an index) +TYPEMAP_ES_TO_AMCAT: dict[ElasticType, FieldType] = { # TEXT fields "text": "text", "annotated_text": "text", @@ -42,14 +46,13 @@ "keyword": "keyword", "constant_keyword": "keyword", "wildcard": "keyword", - # NUMBER fields - # - integer + # INTEGER fields "integer": "number", "byte": "number", "short": "number", "long": "number", "unsigned_long": "number", - # - float + # NUMBER fields "float": "number", "half_float": "number", "double": "number", @@ -61,31 +64,52 @@ # VECTOR fields (exclude sparse vectors) "dense_vector": "vector", # GEO fields - "geo_point": "geo", + "geo_point": "geo_point", } +# maps amcat field types to elastic field types. +# The first elastic type in the array is the default. +TYPEMAP_AMCAT_TO_ES: dict[FieldType, list[ElasticType]] = { + "text": ["text", "annotated_text", "binary", "match_only_text"], + "date": ["date"], + "boolean": ["boolean"], + "keyword": ["keyword", "constant_keyword", "wildcard"], + "number": ["double", "float", "half_float", "scaled_float"], + "integer": ["long", "integer", "byte", "short", "unsigned_long"], + "object": ["object", "flattened", "nested"], + "vector": ["dense_vector"], + "geo_point": ["geo_point"], + "tag": ["keyword"], + "image_url": ["wildcard"], +} -def get_default_metareader(type_group: TypeGroup): - if type_group in ["boolean", "number", "date"]: + +def get_default_metareader(type: FieldType): + if type in ["boolean", "number", "date"]: return FieldMetareaderAccess(access="read") return FieldMetareaderAccess(access="none") -def get_default_field(elastic_type: ElasticType): - type_group = TYPEMAP_ES_TO_AMCAT.get(elastic_type) - if type_group is None: - raise ValueError(f"Invalid elastic type: {elastic_type}") - - return Field(type_group=type_group, type=elastic_type, metareader=get_default_metareader(type_group)) +def get_default_field(type: FieldType): + """ + Generate a field on the spot with default settings. + Primary use case is importing existing indices with fields that are not registered in the system index. 
+ """ + elastic_type = TYPEMAP_AMCAT_TO_ES.get(type) + if elastic_type is None: + raise ValueError( + f"The default elastic type mapping for field type {type} is not defined (if this happens, blame and inform Kasper)" + ) + return Field(elastic_type=elastic_type[0], type=type, metareader=get_default_metareader(type)) -def _standardize_createfields(fields: Mapping[str, ElasticType | CreateField]) -> dict[str, CreateField]: +def _standardize_createfields(fields: Mapping[str, FieldType | CreateField]) -> dict[str, CreateField]: sfields = {} for k, v in fields.items(): if isinstance(v, str): assert v in get_args(ElasticType), f"Unknown elastic type {v}" - sfields[k] = CreateField(type=cast(ElasticType, v)) + sfields[k] = CreateField(type=cast(FieldType, v)) else: sfields[k] = v return sfields @@ -116,30 +140,33 @@ def coerce_type(value: Any, elastic_type: ElasticType): return value -def create_fields(index: str, fields: Mapping[str, ElasticType | CreateField]): +def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): mapping: dict[str, Any] = {} current_fields = get_fields(index) sfields = _standardize_createfields(fields) for field, settings in sfields.items(): - if TYPEMAP_ES_TO_AMCAT.get(settings.type) is None: - raise ValueError(f"Field type {settings.type} not supported by AmCAT") + if settings.elastic_type is not None: + allowed_types = TYPEMAP_AMCAT_TO_ES.get(settings.type, []) + if settings.elastic_type not in allowed_types: + raise ValueError( + f"Field type {settings.type} does not support elastic type {settings.elastic_type}. " + f"Allowed types are: {allowed_types}" + ) + elastic_type = settings.elastic_type + else: + elastic_type = get_default_field(settings.type).elastic_type current = current_fields.get(field) if current is not None: # fields can already exist. For example, a scraper might include the field types in every # upload request. If a field already exists, we'll ignore it, but we will throw an error - # if static settings (field type, identifier) do not match. - if current.type != settings.type: - raise ValueError( - f"Field '{field}' already exists with type '{current.type}'. " f"Cannot change type to '{settings.type}'" - ) - if current.identifier and not settings.identifier: - raise ValueError(f"Field '{field}' is an identifier, cannot change to non-identifier") - if not current.identifier and settings.identifier: - raise ValueError(f"Field '{field}' is not an identifier, cannot change to identifier") - + # if static settings (elastic type, identifier) do not match. + if current.elastic_type != elastic_type: + raise ValueError(f"Field '{field}' already exists with elastic type '{current.elastic_type}'. ") + if current.identifier != settings.identifier: + raise ValueError(f"Field '{field}' already exists with identifier '{current.identifier}'. 
") continue # if field does not exist, we add it to both the mapping and the system index @@ -149,9 +176,9 @@ def create_fields(index: str, fields: Mapping[str, ElasticType | CreateField]): current_fields[field] = Field( type=settings.type, - type_group=TYPEMAP_ES_TO_AMCAT[settings.type], + elastic_type=elastic_type, identifier=settings.identifier, - metareader=settings.metareader or get_default_metareader(TYPEMAP_ES_TO_AMCAT[settings.type]), + metareader=settings.metareader or get_default_metareader(settings.type), client_settings=settings.client_settings or {}, ) @@ -188,6 +215,16 @@ def update_fields(index: str, fields: dict[str, UpdateField]): if current is None: raise ValueError(f"Field {field} does not exist") + if new_settings.type is not None: + valid_es_types = TYPEMAP_AMCAT_TO_ES.get(new_settings.type) + if valid_es_types is None: + raise ValueError(f"Invalid field type: {new_settings.type}") + if current.elastic_type not in valid_es_types: + raise ValueError( + f"Field {field} has the elastic type {current.elastic_type}. A {new_settings.type} field can only have the following elastic types: {valid_es_types}." + ) + current_fields[field].type = new_settings.type + if new_settings.metareader is not None: if current.type != "text" and new_settings.metareader.access == "snippet": raise ValueError(f"Field {field} is not of type text, cannot set metareader access to snippet") @@ -232,16 +269,16 @@ def get_fields(index: str) -> dict[str, Field]: update_system_index = False for field, elastic_type in _get_index_fields(index): - type_group = TYPEMAP_ES_TO_AMCAT.get(elastic_type) + type = TYPEMAP_ES_TO_AMCAT.get(elastic_type) - if type_group is None: + if type is None: # skip over unsupported elastic fields. # (TODO: also return warning to client?) 
continue if field not in system_index_fields: update_system_index = True - fields[field] = get_default_field(elastic_type) + fields[field] = get_default_field(type) else: fields[field] = system_index_fields[field] @@ -263,14 +300,14 @@ def create_or_verify_tag_field(index: str | list[str], field: str): for i in indices: current_fields = get_fields(i) if field in current_fields: - if current_fields[field].tag is False: + if current_fields[field].type != "tag": raise ValueError(f"Field '{field}' already exists in index '{i}' and is not a tag field") else: add_to_indices.append(i) for i in add_to_indices: - current_fields[field] = Field(type="keyword", type_group=TYPEMAP_ES_TO_AMCAT["keyword"], tag=True) + current_fields[field] = get_default_field("tag") es().indices.put_mapping(index=index, properties={field: {"type": "keyword"}}) es().update( index=get_settings().system_index, diff --git a/amcat4/index.py b/amcat4/index.py index e6c45ed..9140525 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -405,7 +405,7 @@ def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: Get the hash for a document """ - identifiers = [k for k, v in field_settings.items() if v.identifier] + identifiers = [k for k, v in field_settings.items() if v.identifier == True] if len(identifiers) == 0: # if no identifiers specified, id is hash of entire document hash_values = document diff --git a/amcat4/models.py b/amcat4/models.py index f9de2f8..4f10477 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -1,9 +1,12 @@ +from xml.dom.domreg import registered import pydantic from pydantic import BaseModel from typing import Annotated, Any, Literal -TypeGroup = Literal["text", "date", "boolean", "keyword", "number", "object", "vector", "geo"] +FieldType = Literal[ + "text", "date", "boolean", "keyword", "number", "integer", "object", "vector", "geo_point", "image_url", "tag" +] ElasticType = Literal[ "text", "annotated_text", @@ -51,13 +54,12 @@ class FieldMetareaderAccess(BaseModel): class Field(BaseModel): - """Settings for a field. Some settings, such as metareader, have a strict model because they are used + """Settings for a field. Some settings, such as metareader, have a strict type because they are used server side. 
Others, such as client_settings, are free-form and can be used by the client to store settings.""" - type: ElasticType - type_group: TypeGroup + type: FieldType + elastic_type: ElasticType identifier: bool = False - tag: bool = False metareader: FieldMetareaderAccess = FieldMetareaderAccess() client_settings: dict[str, Any] = {} @@ -65,7 +67,8 @@ class Field(BaseModel): class CreateField(BaseModel): """Model for creating a field""" - type: ElasticType + type: FieldType + elastic_type: ElasticType | None = None identifier: bool = False metareader: FieldMetareaderAccess | None = None client_settings: dict[str, Any] | None = None @@ -74,6 +77,7 @@ class CreateField(BaseModel): class UpdateField(BaseModel): """Model for updating a field""" + type: FieldType | None = None metareader: FieldMetareaderAccess | None = None client_settings: dict[str, Any] | None = None diff --git a/tests/test_elastic.py b/tests/test_elastic.py index e008636..dabfe8a 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -1,4 +1,5 @@ from datetime import datetime +from re import I from amcat4.index import ( refresh_index, From 666464e83af5d51c200b0ee97915adacdba7e99e Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 3 Apr 2024 18:30:32 +0200 Subject: [PATCH 43/80] only hash ids if too long --- amcat4/__main__.py | 4 ++-- amcat4/fields.py | 40 ++++++++++++++++++++++++++---------- amcat4/index.py | 37 ++++++++++++++++++++++----------- amcat4/models.py | 16 +++++++++++++-- tests/conftest.py | 28 +++++++++++++------------ tests/test_api_documents.py | 3 ++- tests/test_api_pagination.py | 2 +- tests/test_api_query.py | 1 - tests/test_pagination.py | 13 ++++++++---- 9 files changed, 97 insertions(+), 47 deletions(-) diff --git a/amcat4/__main__.py b/amcat4/__main__.py index cd6eaa1..81d1b7b 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -23,7 +23,7 @@ from amcat4.config import get_settings, AuthOptions, validate_settings from amcat4.elastic import connect_elastic, get_system_version, ping from amcat4.index import GLOBAL_ROLES, create_index, set_global_role, Role, list_global_users, upload_documents -from amcat4.models import ElasticType +from amcat4.models import ElasticType, FieldType SOTU_INDEX = "state_of_the_union" @@ -48,7 +48,7 @@ def upload_test_data() -> str: ) for row in csvfile ] - fields: dict[str, ElasticType] = {"president": "keyword", "party": "keyword", "year": "short"} + fields: dict[str, FieldType] = {"president": "keyword", "party": "keyword", "year": "integer"} upload_documents(SOTU_INDEX, docs, fields) return SOTU_INDEX diff --git a/amcat4/fields.py b/amcat4/fields.py index be8c6c7..d2012bc 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -15,6 +15,8 @@ """ from hmac import new +import json +from tabnanny import check from typing import Any, Iterator, Mapping, get_args, cast @@ -79,8 +81,9 @@ "object": ["object", "flattened", "nested"], "vector": ["dense_vector"], "geo_point": ["geo_point"], - "tag": ["keyword"], - "image_url": ["wildcard"], + "tag": ["keyword", "wildcard"], + "image_url": ["wildcard", "keyword", "constant_keyword", "text"], + "json": ["text"], } @@ -115,26 +118,38 @@ def _standardize_createfields(fields: Mapping[str, FieldType | CreateField]) -> return sfields -def coerce_type(value: Any, elastic_type: ElasticType): +def check_forbidden_type(field: Field, type: FieldType): + if field.identifier: + for forbidden_type in ["tag", "vector"]: + if type == forbidden_type: + raise ValueError(f"Field {field} is an identifier field, which cannot be a 
{forbidden_type} field") + + +def coerce_type(value: Any, type: FieldType): """ Coerces values into the respective type in elastic based on ES_MAPPINGS and elastic field types """ - if elastic_type in ["text", "annotated_text", "binary", "match_only_text", "keyword", "constant_keyword", "wildcard"]: + if type in ["text", "tag", "image_url", "date"]: return str(value) - if elastic_type in ["boolean"]: + if type in ["boolean"]: return bool(value) - if elastic_type in ["long", "integer", "short", "byte", "unsigned_long"]: - return int(value) - if elastic_type in ["float", "half_float", "double", "scaled_float"]: + if type in ["number"]: return float(value) + if type in ["integer"]: + return int(value) + + if type == "json": + if isinstance(value, str): + return value + return json.dumps(value) # TODO: check coercion / validation for object, vector and geo types - if elastic_type in ["object", "flattened", "nested"]: + if type in ["object"]: return value - if elastic_type in ["dense_vector"]: + if type in ["vector"]: return value - if elastic_type in ["geo_point"]: + if type in ["geo_point"]: return value return value @@ -181,6 +196,7 @@ def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): metareader=settings.metareader or get_default_metareader(settings.type), client_settings=settings.client_settings or {}, ) + check_forbidden_type(current_fields[field], settings.type) if len(mapping) > 0: es().indices.put_mapping(index=index, properties=mapping) @@ -216,6 +232,8 @@ def update_fields(index: str, fields: dict[str, UpdateField]): raise ValueError(f"Field {field} does not exist") if new_settings.type is not None: + check_forbidden_type(current, new_settings.type) + valid_es_types = TYPEMAP_AMCAT_TO_ES.get(new_settings.type) if valid_es_types is None: raise ValueError(f"Invalid field type: {new_settings.type}") diff --git a/amcat4/index.py b/amcat4/index.py index 9140525..6d2ed36 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -51,7 +51,7 @@ create_or_verify_tag_field, get_fields, ) -from amcat4.models import CreateField, ElasticType, Field +from amcat4.models import CreateField, Field, FieldType class Role(IntEnum): @@ -400,18 +400,33 @@ def delete_user(email: str) -> None: set_role(ix.id, email, None) -def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: +def create_id(document: dict, field_settings: dict[str, Field]) -> str: """ - Get the hash for a document + Create the _id for a document. """ identifiers = [k for k, v in field_settings.items() if v.identifier == True] if len(identifiers) == 0: # if no identifiers specified, id is hash of entire document + # (we could also decide that in this case we let elastic create a uuid) hash_values = document else: - # if identifiers specified, id is hash of those fields - hash_values = {k: document.get(k) for k in identifiers if k in document} + # if identifiers specified, we concatenate the values of these fields, and hash them if + # the string exceeds 512 characters (the maximum length of an id in elastic) + # we only use the fields that are present in the document, and sort them alphabetically, + # so that the id is the same after more identifiers are added. 
+ id_fields = sorted(set(identifiers) & set(document.keys())) + + if not id_fields: + raise ValueError(f"None of the identifier fields {identifiers} are present in the document") + if len(id_fields) == 1: + return str(document[id_fields[0]]) + + id = "|".join(f"{k}={str(document[k])}" for k in id_fields) + if len(id.encode("utf-8")) < 500: + return id + + hash_values = {k: document.get(k) for k in id_fields} hash_str = json.dumps(hash_values, sort_keys=True, ensure_ascii=True, default=str).encode("ascii") m = hashlib.sha224() @@ -422,7 +437,7 @@ def _get_hash(document: dict, field_settings: dict[str, Field]) -> str: def upload_documents( index: str, documents: list[dict[str, Any]], - fields: Mapping[str, ElasticType | CreateField] | None = None, + fields: Mapping[str, FieldType | CreateField] | None = None, op_type="index", return_ids=True, ): @@ -439,22 +454,20 @@ def upload_documents( def es_actions(index, documents, op_type): field_settings = get_fields(index) for document in documents: - for key in document.keys(): if key == "_id": - continue + raise ValueError("You cannot directly set the '_id' field in a document.") if key not in field_settings: raise ValueError(f"The type for field '{key}' is not yet specified") document[key] = coerce_type(document[key], field_settings[key].type) - if "_id" not in document: - document["_id"] = _get_hash(document, field_settings) + + id = create_id(document, field_settings) # https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html if op_type == "update": - id = document.pop("_id") yield {"_op_type": op_type, "_index": index, "_id": id, "doc": document, "doc_as_upsert": True} else: - yield {"_op_type": op_type, "_index": index, **document} + yield {"_op_type": op_type, "_index": index, "_id": id, **document} actions = list(es_actions(index, documents, op_type)) successes, failures = elasticsearch.helpers.bulk(es(), actions, stats_only=True, raise_on_error=False) diff --git a/amcat4/models.py b/amcat4/models.py index 4f10477..302f0fa 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -1,11 +1,13 @@ +from curses import OK from xml.dom.domreg import registered import pydantic -from pydantic import BaseModel +from pydantic import BaseModel, field_validator, model_validator, validator from typing import Annotated, Any, Literal +from typing_extensions import Self FieldType = Literal[ - "text", "date", "boolean", "keyword", "number", "integer", "object", "vector", "geo_point", "image_url", "tag" + "text", "date", "boolean", "keyword", "number", "integer", "object", "vector", "geo_point", "image_url", "tag", "json" ] ElasticType = Literal[ "text", @@ -63,6 +65,16 @@ class Field(BaseModel): metareader: FieldMetareaderAccess = FieldMetareaderAccess() client_settings: dict[str, Any] = {} + @model_validator(mode="after") + def validate_type(self) -> Self: + if self.identifier: + # Identifiers have to be immutable. Instead of checking this in every endpoint that performs updates, + # we can disable it for certain types that are known to be mutable. 
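+ # (tag values are meant to change over time, e.g. via the tags_update endpoint, so they cannot serve as stable ids)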
+ for forbidden_type in ["tag"]: + if self.type == forbidden_type: + raise ValueError(f"Field type {forbidden_type} cannot be used as an identifier") + return self + class CreateField(BaseModel): """Model for creating a field""" diff --git a/tests/conftest.py b/tests/conftest.py index f3beb29..5a5a8dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ set_global_role, upload_documents, ) -from amcat4.models import CreateField, ElasticType +from amcat4.models import CreateField, ElasticType, FieldType from tests.middlecat_keypair import PUBLIC_KEY UNITS = [ @@ -136,7 +136,7 @@ def guest_index(): delete_index(index, ignore_missing=True) -def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, ElasticType | CreateField] | None = None): +def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, FieldType | CreateField] | None = None): """ Upload these docs to the index, giving them an incremental id, and flush """ @@ -146,10 +146,10 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, ElasticType TEST_DOCUMENTS = [ - {"_id": 0, "cat": "a", "subcat": "x", "i": 1, "date": "2018-01-01", "text": "this is a text", "title": "title"}, - {"_id": 1, "cat": "a", "subcat": "x", "i": 2, "date": "2018-02-01", "text": "a test text", "title": "title"}, + {"id": 0, "cat": "a", "subcat": "x", "i": 1, "date": "2018-01-01", "text": "this is a text", "title": "title"}, + {"id": 1, "cat": "a", "subcat": "x", "i": 2, "date": "2018-02-01", "text": "a test text", "title": "title"}, { - "_id": 2, + "id": 2, "cat": "a", "subcat": "y", "i": 11, @@ -158,7 +158,7 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, ElasticType "title": "bla", }, { - "_id": 3, + "id": 3, "cat": "b", "subcat": "y", "i": 31, @@ -170,16 +170,18 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, ElasticType def populate_index(index): + upload( index, TEST_DOCUMENTS, fields={ - "text": CreateField(type="text"), - "title": CreateField(type="text"), - "date": CreateField(type="date"), - "cat": CreateField(type="keyword"), - "subcat": CreateField(type="keyword"), - "i": CreateField(type="long"), + "id": CreateField(type="integer", identifier=True), + "text": "text", + "title": "text", + "date": "date", + "cat": "keyword", + "subcat": "keyword", + "i": "integer", }, ) return TEST_DOCUMENTS @@ -203,7 +205,7 @@ def index_many(): upload( index, [dict(id=i, pagenr=abs(10 - i), text=text) for (i, text) in enumerate(["odd", "even"] * 10)], - fields={"id": CreateField(type="long"), "pagenr": CreateField(type="long")}, + fields={"id": "integer", "pagenr": "integer", "text": "text"}, ) yield index delete_index(index, ignore_missing=True) diff --git a/tests/test_api_documents.py b/tests/test_api_documents.py index 612a219..feff654 100644 --- a/tests/test_api_documents.py +++ b/tests/test_api_documents.py @@ -27,8 +27,9 @@ def test_documents(client, index, user): f"index/{index}/documents", user=user, json={ - "documents": [{"_id": "id", "title": "a title", "text": "text", "date": "2020-01-01"}], + "documents": [{"id": "id", "title": "a title", "text": "text", "date": "2020-01-01"}], "fields": { + "id": {"type": "keyword", "identifier": True}, "title": {"type": "text"}, "text": {"type": "text"}, "date": {"type": "date"}, diff --git a/tests/test_api_pagination.py b/tests/test_api_pagination.py index c6452db..4151c60 100644 --- a/tests/test_api_pagination.py +++ b/tests/test_api_pagination.py @@ -11,7 +11,7 @@ def test_pagination(client, index, user): # 
TODO. Tests are not independent. test_pagination fails if run directly after other tests. # Probably delete_index doesn't fully delete - upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(type="long")}) + upload(index, docs=[{"i": i} for i in range(66)], fields={"i": "integer"}) url = f"/index/{index}/query" r = post_json(client, url, user=user, json={"sort": "i", "per_page": 20, "fields": ["i"]}, expected=200) diff --git a/tests/test_api_query.py b/tests/test_api_query.py index 7c2ba2f..939fd08 100644 --- a/tests/test_api_query.py +++ b/tests/test_api_query.py @@ -116,7 +116,6 @@ def test_multiple_index(client, index_docs, index, user): json={"axes": [{"field": "cat"}], "fields": ["_id"]}, expected=200, ) - print(r) assert dictset(r["data"]) == dictset([{"cat": "a", "n": 3}, {"n": 1, "cat": "b"}, {"n": 1, "cat": "c"}]) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 4f63a66..60d5821 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -18,11 +18,16 @@ def test_pagination(index_many): def test_sort(index_many): def q(key, per_page=5) -> List[int]: - res = query_documents(index_many, per_page=per_page, sort=key) - return [int(h["_id"]) for h in res.data] - assert q("id") == [0, 1, 2, 3, 4] - assert q("pagenr") == [10, 9, 11, 8, 12] + for i, k in enumerate(key): + if isinstance(k, str): + key[i] = {k: {"order": "asc"}} + res = query_documents(index_many, per_page=per_page, fields=[FieldSpec(name="id")], sort=key) + print(list(res.data)) + return [int(h["id"]) for h in res.data] + + assert q(["id"]) == [0, 1, 2, 3, 4] + assert q(["pagenr"]) == [10, 9, 11, 8, 12] assert q(["pagenr", "id"]) == [10, 9, 11, 8, 12] assert q([{"pagenr": {"order": "desc"}}, "id"]) == [0, 1, 19, 2, 18] From a65588c9dc6349ac28f83fa8372ef7984fc76b3e Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Sat, 6 Apr 2024 14:15:16 +1000 Subject: [PATCH 44/80] Always allow fields if no_auth --- amcat4/api/query.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index 3de1827..1a37102 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -10,6 +10,7 @@ from amcat4 import query, aggregate from amcat4.aggregate import Axis, Aggregation from amcat4.api.auth import authenticated_user, check_fields_access +from amcat4.config import AuthOptions, get_settings from amcat4.fields import create_fields from amcat4.index import Role, get_role, get_fields from amcat4.models import FieldSpec, FilterSpec, FilterValue, SortSpec @@ -43,7 +44,6 @@ def get_or_validate_allowed_fields( they are allowed to see. If fields is None, return all allowed fields. If fields is not None, check whether the user can access the fields (If not, raise an error). 
""" - if not isinstance(user, str): raise ValueError("User should be a string") if not isinstance(indices, list): @@ -51,6 +51,7 @@ def get_or_validate_allowed_fields( if fields is not None and not isinstance(fields, list): raise ValueError("Fields should be a list or None") + no_auth = get_settings().auth == AuthOptions.no_auth if fields is None: if len(indices) > 1: # this restrictions is needed, because otherwise we need to return all allowed fields taking @@ -62,7 +63,7 @@ def get_or_validate_allowed_fields( role = get_role(indices[0], user) allowed_fields: list[FieldSpec] = [] for field in index_fields.keys(): - if role >= Role.READER: + if role >= Role.READER or no_auth: allowed_fields.append(FieldSpec(name=field)) elif role == Role.METAREADER: metareader = index_fields[field].metareader @@ -78,7 +79,8 @@ def get_or_validate_allowed_fields( return allowed_fields for index in indices: - check_fields_access(index, user, fields) + if not no_auth: + check_fields_access(index, user, fields) return fields @@ -234,7 +236,6 @@ def query_documents_post( """ indices = index.split(",") fieldspecs = get_or_validate_allowed_fields(user, indices, _standardize_fieldspecs(fields)) - r = query.query_documents( indices, queries=_standardize_queries(queries), From 0fe7ffd266f71620e72384753e9aac7df065f9bb Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Sat, 6 Apr 2024 16:29:53 +1000 Subject: [PATCH 45/80] Don't create _admin user entry for new indexes in no_auth --- amcat4/api/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 36daae9..acb1f65 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -69,7 +69,7 @@ def create_index(new_index: NewIndex, current_user: str = Depends(authenticated_ guest_role=guest_role, name=new_index.name, description=new_index.description, - admin=current_user, + admin=current_user if current_user != "_admin" else None, ) except ApiError as e: raise HTTPException( From 93e213c228c0bb04c2c8eda8b5eb84dc20a57194 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Mon, 8 Apr 2024 12:49:45 +1000 Subject: [PATCH 46/80] mapping should use elastic type --- amcat4/fields.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/amcat4/fields.py b/amcat4/fields.py index d2012bc..13526ed 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -1,9 +1,9 @@ """ We have two types of fields: -- Elastic fields are the fields used under the hood by elastic. +- Elastic fields are the fields used under the hood by elastic. (https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-types.html These are stored in the Mapping of an index -- Amcat fields (Field) are the fields are seen by the amcat user. They use a simplified type, and contain additional +- Amcat fields (Field) are the fields are seen by the amcat user. They use a simplified type, and contain additional information such as metareader access These are stored in the system index. 
@@ -185,7 +185,7 @@ def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): continue # if field does not exist, we add it to both the mapping and the system index - mapping[field] = {"type": settings.type} + mapping[field] = {"type": elastic_type} if settings.type in ["date"]: mapping[field]["format"] = "strict_date_optional_time" From 48747f5c7db47eb8687dc6e68c6861831ddc95f0 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Mon, 8 Apr 2024 14:16:14 +0200 Subject: [PATCH 47/80] fixed bug in confusing fieldtype with elastictype --- amcat4/api/index.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index acb1f65..db4664a 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -15,7 +15,7 @@ from amcat4.index import refresh_system_index, remove_role, set_role from amcat4.fields import field_values, field_stats -from amcat4.models import CreateField, ElasticType, UpdateField +from amcat4.models import CreateField, ElasticType, FieldType, UpdateField app_index = APIRouter(prefix="/index", tags=["index"]) @@ -282,7 +282,7 @@ def delete_document(ix: str, docid: str, user: str = Depends(authenticated_user) def create_fields( ix: str, fields: Annotated[ - dict[str, ElasticType | CreateField], + dict[str, FieldType | CreateField], Body( description="Either a dictionary that maps field names to field specifications" "({field: {type: 'text', identifier: True }}), " From 550bd2cd2b3ca5b55cc31bacd7b2dbc44dc82f90 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Mon, 8 Apr 2024 17:01:33 +0200 Subject: [PATCH 48/80] added refresh to update by query, so that tags update immediately --- amcat4/api/query.py | 1 - amcat4/index.py | 10 +++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index 1a37102..ffb3189 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -363,7 +363,6 @@ def query_update_tags( Add or remove tags by query or by id """ indices = index.split(",") - for i in indices: if get_role(i, user) < Role.WRITER: raise HTTPException( diff --git a/amcat4/index.py b/amcat4/index.py index 6d2ed36..8861643 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -470,7 +470,11 @@ def es_actions(index, documents, op_type): yield {"_op_type": op_type, "_index": index, "_id": id, **document} actions = list(es_actions(index, documents, op_type)) - successes, failures = elasticsearch.helpers.bulk(es(), actions, stats_only=True, raise_on_error=False) + successes, failures = elasticsearch.helpers.bulk( + es(), + actions, + stats_only=True, + ) @@ -513,7 +517,7 @@ def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None = None): script_dict = dict(source=script, lang="painless", params=params or {}) - test = es().update_by_query(index=index, script=script_dict, **query) + test = es().update_by_query(index=index, script=script_dict, **query, refresh=True) From bf794ec1925d710b10e793633200d24a1e1cc3c0 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 9 Apr 2024 10:40:40 +1000 Subject: [PATCH 49/80] Fixed query tests --- tests/test_api_query.py | 
21 ++++++++++++--------- tests/test_query.py | 2 +- tests/tools.py | 4 +++- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/test_api_query.py b/tests/test_api_query.py index 939fd08..dff84dd 100644 --- a/tests/test_api_query.py +++ b/tests/test_api_query.py @@ -2,7 +2,7 @@ from amcat4.models import CreateField, FieldSpec from amcat4.query import query_documents from tests.conftest import upload -from tests.tools import post_json, dictset +from tests.tools import build_headers, check, post_json, dictset def test_query_post(client, index_docs, user): @@ -28,7 +28,8 @@ def qi(**query_string): assert qi(filters={"cat": {"values": ["a"]}}) == {0, 1, 2} # Can we request specific fields? - all_fields = {"_id", "cat", "subcat", "i", "date", "text", "title"} + all_fields = {"_id", "id", "cat", "subcat", "i", "date", "text", "title"} + print(q()[0]) assert set(q()[0].keys()) == all_fields assert set(q(fields=["cat"])[0].keys()) == {"_id", "cat"} assert set(q(fields=["date", "title"])[0].keys()) == {"_id", "date", "title"} @@ -58,7 +59,7 @@ def test_aggregate(client, index_docs, user): }, ) assert dictset(r["data"]) == dictset([{"avg_i": 1.5, "n": 2, "subcat": "x"}, {"avg_i": 21.0, "n": 2, "subcat": "y"}]) - assert r["meta"]["aggregations"] == [{"field": "i", "function": "avg", "type": "number", "name": "avg_i"}] + assert r["meta"]["aggregations"] == [{"field": "i", "function": "avg", "type": "integer", "name": "avg_i"}] # test filtered aggregate r = post_json( @@ -95,7 +96,7 @@ def test_multiple_index(client, index_docs, index, user): fields={ "text": CreateField(type="text"), "cat": CreateField(type="keyword"), - "i": CreateField(type="long"), + "i": CreateField(type="integer"), }, ) indices = f"{index},{index_docs}" @@ -150,11 +151,13 @@ def test_aggregate_datemappings(client, index_docs, user): def test_query_tags(client, index_docs, user): def tags(): - return { - doc["_id"]: doc["tag"] - for doc in query_documents(index_docs, fields=[FieldSpec(name="tag")]).data - if doc.get("tag") - } + result = query_documents(index_docs, fields=[FieldSpec(name="tag")]) + return {doc["_id"]: doc["tag"] for doc in (result.data if result else []) if doc.get("tag")} + + check(client.post(f"/index/{index_docs}/tags_update"), 401) + check(client.post(f"/index/{index_docs}/tags_update", headers=build_headers(user=user)), 401) + + set_role(index_docs, user, Role.WRITER) assert tags() == {} post_json( diff --git a/tests/test_query.py b/tests/test_query.py index b257348..f528027 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -89,7 +89,7 @@ def test_highlight(index): def test_query_multiple_index(index_docs, index): - upload(index, [{"text": "also a text", "i": -1}], fields={"i": "long", "text": "text"}) + upload(index, [{"text": "also a text", "i": -1}], fields={"i": "integer", "text": "text"}) docs = query.query_documents([index_docs, index]) assert docs is not None assert len(docs.data) == 5 diff --git a/tests/tools.py b/tests/tools.py index 1532f5c..e2651af 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -44,7 +44,9 @@ def post_json(client: TestClient, url, expected=201, headers=None, user=None, ** assert response.status_code == expected, ( f"POST {url} returned {response.status_code}, expected {expected}\n" f"{response.json()}" ) - if expected != 204: + if expected == 204: + return {} + else: return response.json() From cc139ecf843798c6b92d49095fb6510be7c9adeb Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 9 Apr 2024 10:54:41 +1000 Subject: [PATCH 50/80] 
Allow aggregate without either axes or aggregation, giving n only --- amcat4/api/query.py | 5 ----- tests/test_aggregate.py | 3 +++ tests/test_api_query.py | 30 ++++++++++++++++++++++++++++++ tests/test_query.py | 10 +++++++--- 4 files changed, 40 insertions(+), 8 deletions(-) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index ffb3189..9745704 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -307,11 +307,6 @@ def query_aggregate_post( indices = index.split(",") _axes = [Axis(**x.model_dump()) for x in axes] if axes else [] _aggregations = [Aggregation(**x.model_dump()) for x in aggregations] if aggregations else [] - if not (_axes or _aggregations): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Aggregation needs at least one axis or aggregation", - ) results = aggregate.query_aggregate( indices, diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index c1e41cc..7ec3ee8 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -92,6 +92,9 @@ def q(axes, aggregations): assert q(None, [Aggregation("i", "avg")]) == dictset([{"n": 4, "avg_i": 11.25}]) assert q(None, [Aggregation("i", "avg"), Aggregation("i", "max")]) == dictset([{"n": 4, "avg_i": 11.25, "max_i": 31.0}]) + # Count only + assert q([], []) == dictset([{"n": 4}]) + # Check value handling - Aggregation on date fields assert q(None, [Aggregation("date", "max")]) == dictset([{"n": 4, "max_date": "2020-01-01T00:00:00"}]) assert q([Axis("subcat")], [Aggregation("date", "avg")]) == dictset( diff --git a/tests/test_api_query.py b/tests/test_api_query.py index dff84dd..100c539 100644 --- a/tests/test_api_query.py +++ b/tests/test_api_query.py @@ -88,6 +88,36 @@ def test_aggregate(client, index_docs, user): assert data == {"x": 2} +def test_bare_aggregate(client, index_docs, user): + r = post_json( + client, + f"/index/{index_docs}/aggregate", + user=user, + expected=200, + json={}, + ) + assert r["meta"]["axes"] == [] + assert r["data"] == [dict(n=4)] + + r = post_json( + client, + f"/index/{index_docs}/aggregate", + user=user, + expected=200, + json={"aggregations": [{"field": "i", "function": "avg"}]}, + ) + assert r["data"] == [dict(n=4, avg_i=11.25)] + + r = post_json( + client, + f"/index/{index_docs}/aggregate", + user=user, + expected=200, + json={"aggregations": [{"field": "i", "function": "min", "name": "mini"}]}, + ) + assert r["data"] == [dict(n=4, mini=1)] + + def test_multiple_index(client, index_docs, index, user): set_role(index, user, Role.READER) upload( diff --git a/tests/test_query.py b/tests/test_query.py index f528027..87533a8 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -40,11 +40,13 @@ def test_query(index_docs): def test_snippet(index_docs): docs = query.query_documents(index_docs, fields=[FieldSpec(name="text", snippet=SnippetParams(nomatch_chars=5))]) + assert docs is not None assert docs.data[0]["text"] == "this is" docs = query.query_documents( index_docs, queries={"1": "a"}, fields=[FieldSpec(name="text", snippet=SnippetParams(max_matches=1, match_chars=1))] ) + assert docs is not None assert docs.data[0]["text"] == "a" @@ -74,7 +76,7 @@ def test_highlight(index): assert doc["title"] == "Een test titel" assert doc["text"] == f"{words} a test document. {words} other text documents. {words} you!" 
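# note: when a snippet has multiple matches, the returned fragments are joined with " ... "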
- doc = query.query_documents( + res = query.query_documents( index, queries={"1": "te*"}, fields=[ @@ -82,8 +84,10 @@ def test_highlight(index): FieldSpec(name="text", snippet=SnippetParams(max_matches=3, match_chars=50)), ], highlight=True, - ).data[0] - assert doc["title"] == "Een test titel" + ) + assert res is not None + doc = res.data[0] + assert doc["title"] == "Een test titel" assert " a test" in doc["text"] assert " ... " in doc["text"] From 265e56f986b2c43f8da4ce2aa1db94a320d1fe27 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 9 Apr 2024 12:09:37 +1000 Subject: [PATCH 51/80] Add url field, fix date serialization, fix various tests --- amcat4/fields.py | 6 +++++- amcat4/index.py | 25 ++++++++++++++++++------- amcat4/models.py | 14 +++++++++++++- tests/test_api_pagination.py | 2 +- tests/test_elastic.py | 35 +++++++++++++++++------------------ tests/test_pagination.py | 8 ++++++++ 6 files changed, 62 insertions(+), 28 deletions(-) diff --git a/amcat4/fields.py b/amcat4/fields.py index 13526ed..28f9ea1 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -14,6 +14,7 @@ system index """ +import datetime from hmac import new import json from tabnanny import check @@ -83,6 +84,7 @@ "geo_point": ["geo_point"], "tag": ["keyword", "wildcard"], "image_url": ["wildcard", "keyword", "constant_keyword", "text"], + "url": ["wildcard", "keyword", "constant_keyword", "text"], "json": ["text"], } @@ -111,7 +113,7 @@ def _standardize_createfields(fields: Mapping[str, FieldType | CreateField]) -> sfields = {} for k, v in fields.items(): if isinstance(v, str): - assert v in get_args(ElasticType), f"Unknown elastic type {v}" + assert v in get_args(FieldType), f"Unknown amcat type {v}" sfields[k] = CreateField(type=cast(FieldType, v)) else: sfields[k] = v @@ -130,6 +132,8 @@ def coerce_type(value: Any, type: FieldType): Coerces values into the respective type in elastic based on ES_MAPPINGS and elastic field types """ + if type == "date" and isinstance(value, datetime.date): + return value.isoformat() if type in ["text", "tag", "image_url", "date"]: return str(value) if type in ["boolean"]: diff --git a/amcat4/index.py b/amcat4/index.py index 8861643..6a6799b 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -34,6 +34,7 @@ import collections from enum import IntEnum +import logging from typing import Any, Iterable, Mapping, Optional, Literal import hashlib @@ -438,8 +439,9 @@ def upload_documents( index: str, documents: list[dict[str, Any]], fields: Mapping[str, FieldType | CreateField] | None = None, - op_type="index", + op_type: Literal["index", "update"] = "index", return_ids=True, + raise_on_error=True, ): """ Upload documents to this index @@ -447,6 +449,7 @@ def upload_documents( :param index: The name of the index (without prefix) :param documents: A sequence of article dictionaries :param fields: A mapping of fieldname:UpdateField for field types + :param op_type: Whether to 'index' new documents (default) or 'update' existing documents """ if fields: create_fields(index, fields) @@ -462,7 +465,6 @@ def es_actions(index, documents, op_type): document[key] = coerce_type(document[key], field_settings[key].type) id = create_id(document, field_settings) - # https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html if op_type == "update": yield {"_op_type": op_type, "_index": index, "_id": id, "doc": document, "doc_as_upsert": True} @@ -470,11 +472,20 @@ def es_actions(index, documents, op_type): yield {"_op_type": op_type, "_index": index, "_id": id, 
**document} actions = list(es_actions(index, documents, op_type)) - successes, failures = elasticsearch.helpers.bulk( - es(), - actions, - stats_only=True, - ) + try: + successes, failures = elasticsearch.helpers.bulk( + es(), + actions, + stats_only=True, + raise_on_error=raise_on_error, + ) + except elasticsearch.helpers.BulkIndexError as e: + logging.error("Error on indexing: " + json.dumps(e.errors, indent=2, default=str)) + if e.errors: + _, error = list(e.errors[0].items())[0] + reason = error.get("error", {}).get("reason", error) + e.args = e.args + (f"First error: {reason}",) + raise if return_ids: ids = [doc["_id"] for doc in actions] diff --git a/amcat4/models.py b/amcat4/models.py index 302f0fa..c04ae04 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -7,7 +7,19 @@ FieldType = Literal[ - "text", "date", "boolean", "keyword", "number", "integer", "object", "vector", "geo_point", "image_url", "tag", "json" + "text", + "date", + "boolean", + "keyword", + "number", + "integer", + "object", + "vector", + "geo_point", + "url", + "image_url", + "tag", + "json", ] ElasticType = Literal[ "text", diff --git a/tests/test_api_pagination.py b/tests/test_api_pagination.py index 4151c60..747a1c8 100644 --- a/tests/test_api_pagination.py +++ b/tests/test_api_pagination.py @@ -28,7 +28,7 @@ def test_pagination(client, index, user): def test_scroll(client, index, user): set_role(index, user, Role.READER) - upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(type="long")}) + upload(index, docs=[{"i": i} for i in range(66)], fields={"i": CreateField(type="integer")}) url = f"/index/{index}/query" r = post_json( client, diff --git a/tests/test_elastic.py b/tests/test_elastic.py index dabfe8a..8b31268 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -23,7 +23,7 @@ def test_upload_retrieve_document(index): _id="test", term_tfidf=[{"term": "test", "value": 0.2}, {"term": "value", "value": 0.3}], ) - upload_documents(index, [a], fields={"text": "text", "title": "text", "date": "date", "term_tfidf": "nested"}) + upload_documents(index, [a], fields={"text": "text", "title": "text", "date": "date", "term_tfidf": "object"}) d = get_document(index, "test") assert d["title"] == a["title"] assert d["term_tfidf"] == a["term_tfidf"] @@ -32,7 +32,7 @@ def test_upload_retrieve_document(index): def test_data_coerced(index): """Are field values coerced to the correct field type""" - create_fields(index, {"i": "long", "x": "double", "title": "text", "date": "date", "text": "text"}) + create_fields(index, {"i": "integer", "x": "number", "title": "text", "date": "date", "text": "text"}) a = dict(_id="DoccyMcDocface", text="text", title="test-numeric", date="2022-12-13", i="1", x="1.1") upload_documents(index, [a]) d = get_document(index, "DoccyMcDocface") @@ -45,7 +45,7 @@ def test_data_coerced(index): def test_fields(index): """Can we get the fields from an index""" - create_fields(index, {"title": "text", "date": "date", "text": "text", "url": "keyword"}) + create_fields(index, {"title": "text", "date": "date", "text": "text", "url": "url"}) fields = get_fields(index) assert set(fields.keys()) == {"title", "date", "text", "url"} assert fields["title"].type == "text" @@ -78,11 +78,8 @@ def q(*ids): return dict(query=dict(ids={"values": ids})) def tags(): - return { - doc["_id"]: doc["tag"] - for doc in query_documents(index_docs, fields=[FieldSpec(name="tag")]).data - if "tag" in doc and doc["tag"] is not None - } + res = query_documents(index_docs, 
fields=[FieldSpec(name="tag")]) + return {doc["_id"]: doc["tag"] for doc in (res.data if res else []) if "tag" in doc and doc["tag"] is not None} assert tags() == {} update_tag_by_query(index_docs, "add", q("0", "1"), "tag", "x") @@ -102,25 +99,27 @@ def tags(): def test_deduplication(index): doc = {"title": "titel", "text": "text", "date": datetime(2020, 1, 1)} upload_documents(index, [doc], fields={"title": "text", "text": "text", "date": "date"}) - refresh_index(index) - assert query_documents(index).total_count == 1 + _assert_n(index, 1) upload_documents(index, [doc]) - refresh_index(index) - assert query_documents(index).total_count == 1 + _assert_n(index, 1) def test_identifier_deduplication(index): doc = {"url": "http://", "text": "text"} - upload_documents(index, [doc], fields={"url": CreateField(type="wildcard", identifier=True), "text": "text"}) - refresh_index(index) - assert query_documents(index).total_count == 1 + upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True), "text": "text"}) + _assert_n(index, 1) doc2 = {"url": "http://", "text": "text2"} upload_documents(index, [doc2]) - refresh_index(index) - assert query_documents(index).total_count == 1 + _assert_n(index, 1) doc3 = {"url": "http://2", "text": "text"} upload_documents(index, [doc3]) + _assert_n(index, 2) + + +def _assert_n(index, n): refresh_index(index) - assert query_documents(index).total_count == 2 + res = query_documents(index) + assert res is not None + assert res.total_count == n diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 60d5821..26cad28 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -5,11 +5,13 @@ def test_pagination(index_many): x = query_documents(index_many, per_page=6) + assert x is not None assert x.page_count == 4 assert x.per_page == 6 assert len(x.data) == 6 assert x.page == 0 x = query_documents(index_many, per_page=6, page=3) + assert x is not None assert x.page_count == 4 assert x.per_page == 6 assert len(x.data) == 20 - 3 * 6 @@ -23,6 +25,8 @@ def q(key, per_page=5) -> List[int]: if isinstance(k, str): key[i] = {k: {"order": "asc"}} res = query_documents(index_many, per_page=per_page, fields=[FieldSpec(name="id")], sort=key) + assert res is not None + print(list(res.data)) return [int(h["id"]) for h in res.data] @@ -34,20 +38,24 @@ def q(key, per_page=5) -> List[int]: def test_scroll(index_many): r = query_documents(index_many, queries={"odd": "odd"}, scroll="5m", per_page=4, fields=[FieldSpec(name="id")]) + assert r is not None assert len(r.data) == 4 assert r.total_count, 10 assert r.page_count == 3 allids = list(r.data) r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")]) + assert r is not None assert len(r.data) == 4 allids += r.data r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")]) + assert r is not None assert len(r.data) == 2 allids += r.data r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")]) + assert r is not None assert len(r.data) == 0 assert {int(h["id"]) for h in allids} == {0, 2, 4, 6, 8, 10, 12, 14, 16, 18} From 24b371cf908105e1dfec22b18019ff060123a91e Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 9 Apr 2024 12:51:35 +1000 Subject: [PATCH 52/80] Fixed final tests --- amcat4/index.py | 1 + tests/test_api_errors.py | 12 ++++++++---- tests/test_pagination.py | 3 +-- tests/test_query.py | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/amcat4/index.py 
b/amcat4/index.py index 6a6799b..b81e3fe 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -459,6 +459,7 @@ def es_actions(index, documents, op_type): for document in documents: for key in document.keys(): if key == "_id": + # WvA: Do we really wish to disallow this? raise ValueError("You cannot directly set the '_id' field in a document.") if key not in field_settings: raise ValueError(f"The type for field '{key}' is not yet specified") diff --git a/tests/test_api_errors.py b/tests/test_api_errors.py index 5fac4b9..063d4af 100644 --- a/tests/test_api_errors.py +++ b/tests/test_api_errors.py @@ -12,16 +12,20 @@ def check(client, url, status, message, method="post", user=None, **kargs): raise AssertionError(f"Status {r.status_code} error {repr(r.text)} does not match pattern {repr(message)}") -def test_documents_unauthorized(client, index, writer, ): +def test_documents_unauthorized( + client, + index, + writer, +): check(client, "/index/", 401, "global writer permissions") - check(client, f"/index/{index}/", 401, f"permissions on index {index}", method='get') + check(client, f"/index/{index}/", 401, f"permissions on index {index}", method="get") def test_error_elastic(client, index, admin): for hostname in ("doesnotexist.example.com", "https://doesnotexist.example.com:9200"): - with amcat_settings(elastic_host=hostname): + with amcat_settings(elastic_host=hostname, elastic_verify_ssl=True): es.cache_clear() - check(client, f"/index/{index}/", 500, f"cannot connect.*{hostname}", method='get', user=admin) + check(client, f"/index/{index}/", 500, f"cannot connect.*{hostname}", method="get", user=admin) def test_error_index_create(client, writer, index): diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 26cad28..acfe0db 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -55,7 +55,6 @@ def test_scroll(index_many): allids += r.data r = query_documents(index_many, scroll_id=r.scroll_id, fields=[FieldSpec(name="id")]) - assert r is not None + assert r is None - assert len(r.data) == 0 assert {int(h["id"]) for h in allids} == {0, 2, 4, 6, 8, 10, 12, 14, 16, 18} diff --git a/tests/test_query.py b/tests/test_query.py index 87533a8..ac8bfa4 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -87,7 +87,7 @@ def test_highlight(index): ) assert res is not None doc = res.data[0] - assert doc["title"] == "Een test titel" + assert doc["title"] == "Een test titel" assert " a test" in doc["text"] assert " ... 
" in doc["text"] From 6ca864de8ca224d46e2f10d4117e5608755c969e Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Tue, 9 Apr 2024 10:12:38 +0200 Subject: [PATCH 53/80] stuff I forgot to push --- amcat4/api/query.py | 12 +++++------- amcat4/index.py | 18 +++++++++++++----- amcat4/query.py | 6 +++++- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/amcat4/api/query.py b/amcat4/api/query.py index 9745704..eec895f 100644 --- a/amcat4/api/query.py +++ b/amcat4/api/query.py @@ -327,11 +327,7 @@ def query_aggregate_post( } -@app_query.post( - "/{index}/tags_update", - status_code=status.HTTP_204_NO_CONTENT, - response_class=Response, -) +@app_query.post("/{index}/tags_update") def query_update_tags( index: str, action: Literal["add", "remove"] = Body(None, description="Action (add or remove) on tags"), @@ -367,5 +363,7 @@ def query_update_tags( if isinstance(ids, (str, int)): ids = [ids] - update_tag_query(indices, action, field, tag, _standardize_queries(queries), _standardize_filters(filters), ids) - return + update_result = update_tag_query( + indices, action, field, tag, _standardize_queries(queries), _standardize_filters(filters), ids + ) + return update_result diff --git a/amcat4/index.py b/amcat4/index.py index b81e3fe..ae09572 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -529,15 +529,20 @@ def delete_document(index: str, doc_id: str): def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None = None): script_dict = dict(source=script, lang="painless", params=params or {}) - test = es().update_by_query(index=index, script=script_dict, **query, refresh=True) + result = es().update_by_query(index=index, script=script_dict, **query, refresh=True) + return dict(updated=result["updated"], total=result["total"]) TAG_SCRIPTS = dict( add=""" if (ctx._source[params.field] == null) { ctx._source[params.field] = [params.tag] - } else if (!ctx._source[params.field].contains(params.tag)) { - ctx._source[params.field].add(params.tag) + } else { + if (ctx._source[params.field].contains(params.tag)) { + ctx.op = 'noop'; + } else { + ctx._source[params.field].add(params.tag) + } } """, remove=""" @@ -546,7 +551,10 @@ def update_by_query(index: str | list[str], script: str, query: dict, params: di if (ctx._source[params.field].size() == 0) { ctx._source.remove(params.field); } - }""", + } else { + ctx.op = 'noop'; + } + """, ) @@ -554,4 +562,4 @@ def update_tag_by_query(index: str | list[str], action: Literal["add", "remove"] create_or_verify_tag_field(index, field) script = TAG_SCRIPTS[action] params = dict(field=field, tag=tag) - update_by_query(index, script, query, params) + return update_by_query(index, script, query, params) diff --git a/amcat4/query.py b/amcat4/query.py index 1a17dcd..4f63df9 100644 --- a/amcat4/query.py +++ b/amcat4/query.py @@ -13,6 +13,8 @@ Literal, ) +from urllib3 import Retry + from amcat4.models import FieldSpec, FilterSpec, SortSpec from .date_mappings import mappings @@ -171,6 +173,7 @@ def query_documents( for s in sort: for k, v in s.items(): kwargs["sort"].append({k: dict(v)}) + if scroll_id: result = es().scroll(scroll_id=scroll_id, **kwargs) if not result["hits"]["hits"]: @@ -268,4 +271,5 @@ def update_tag_query( """Add or remove tags using a query""" body = build_body(queries, filters, ids=ids) - update_tag_by_query(index, action, body, field, tag) + update_result = update_tag_by_query(index, action, body, field, tag) + return update_result From 380e505b87b67397cb21fbd5b79abcc16d8c2bf8 Mon Sep 17 00:00:00 2001 
From: Kasper Welbers Date: Tue, 9 Apr 2024 12:38:10 +0200 Subject: [PATCH 54/80] tests working with new identifier rules --- amcat4/api/index.py | 4 +- amcat4/fields.py | 12 ++++++ amcat4/index.py | 68 ++++++++++++++++----------------- amcat4/models.py | 14 +------ tests/conftest.py | 10 ++--- tests/test_api_documents.py | 3 +- tests/test_api_query.py | 23 ++++++------ tests/test_elastic.py | 75 +++++++++++++++++++++++++++++++++---- 8 files changed, 132 insertions(+), 77 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index db4664a..3c54a84 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -180,7 +180,7 @@ def upload_documents( ix: str, documents: Annotated[list[dict[str, Any]], Body(description="The documents to upload")], fields: Annotated[ - dict[str, ElasticType | CreateField] | None, + dict[str, FieldType | CreateField] | None, Body( description="If a field in documents does not yet exist, you can create it on the spot. " "If you only need to specify the type, and use the default settings, " @@ -206,7 +206,7 @@ def upload_documents( check_role(user, index.Role.WRITER, ix) else: check_role(user, index.Role.ADMIN, ix) - return index.upload_documents(ix, documents, fields, operation, return_ids=False) + return index.upload_documents(ix, documents, fields, operation) @app_index.get("/{ix}/documents/{docid}") diff --git a/amcat4/fields.py b/amcat4/fields.py index 28f9ea1..f9060b5 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -164,6 +164,8 @@ def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): current_fields = get_fields(index) sfields = _standardize_createfields(fields) + old_identifiers = any(f.identifier for f in current_fields.values()) + new_identifiers = False for field, settings in sfields.items(): if settings.elastic_type is not None: @@ -189,6 +191,9 @@ def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): continue # if field does not exist, we add it to both the mapping and the system index + if settings.identifier: + new_identifiers = True + mapping[field] = {"type": elastic_type} if settings.type in ["date"]: mapping[field]["format"] = "strict_date_optional_time" @@ -202,7 +207,14 @@ def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): ) check_forbidden_type(current_fields[field], settings.type) + if new_identifiers: + # new identifiers are only allowed if the index had identifiers, or if it is a new index (i.e. no documents) + has_docs = es().count(index=index)["count"] > 0 + if has_docs and not old_identifiers: + raise ValueError("Cannot add identifiers. 
Index already has documents with no identifiers.") + if len(mapping) > 0: + # if there are new identifiers, check whether this is allowed first es().indices.put_mapping(index=index, properties=mapping) es().update( index=get_settings().system_index, diff --git a/amcat4/index.py b/amcat4/index.py index ae09572..886effb 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -33,7 +33,9 @@ """ import collections +from dataclasses import field from enum import IntEnum +import functools import logging from typing import Any, Iterable, Mapping, Optional, Literal @@ -408,28 +410,11 @@ def create_id(document: dict, field_settings: dict[str, Field]) -> str: identifiers = [k for k, v in field_settings.items() if v.identifier == True] if len(identifiers) == 0: - # if no identifiers specified, id is hash of entire document - # (we could also decide that in this case we let elastic create a uuid) - hash_values = document - else: - # if identifiers specified, we concatenate the values of these fields, and hash them if - # the string exceeds 512 characters (the maximum length of an id in elastic) - # we only use the fields that are present in the document, and sort them alphabetically, - # so that the id is the same after more identifiers are added. - id_fields = sorted(set(identifiers) & set(document.keys())) - - if not id_fields: - raise ValueError(f"None of the identifier fields {identifiers} are present in the document") - if len(id_fields) == 1: - return str(document[id_fields[0]]) - - id = "|".join(f"{k}={str(document[k])}" for k in id_fields) - if len(id.encode("utf-8")) < 500: - return id - - hash_values = {k: document.get(k) for k in id_fields} + raise ValueError("Can only create id if identifiers are specified") - hash_str = json.dumps(hash_values, sort_keys=True, ensure_ascii=True, default=str).encode("ascii") + id_keys = sorted(set(identifiers) & set(document.keys())) + id_fields = {k: document[k] for k in id_keys} + hash_str = json.dumps(id_fields, sort_keys=True, ensure_ascii=True, default=str).encode("ascii") m = hashlib.sha224() m.update(hash_str) return m.hexdigest() @@ -439,9 +424,8 @@ def upload_documents( index: str, documents: list[dict[str, Any]], fields: Mapping[str, FieldType | CreateField] | None = None, - op_type: Literal["index", "update"] = "index", - return_ids=True, - raise_on_error=True, + op_type: Literal["create", "update"] = "create", + raise_on_error=False, ): """ Upload documents to this index @@ -456,21 +440,38 @@ def upload_documents( def es_actions(index, documents, op_type): field_settings = get_fields(index) + has_identifiers = any(field.identifier for field in field_settings.values()) for document in documents: + doc = dict() + action = {"_op_type": op_type, "_index": index} + for key in document.keys(): + if key in field_settings: + doc[key] = coerce_type(document[key], field_settings[key].type) + else: + if key != "_id": + raise ValueError(f"Field '{key}' is not yet specified") + if key == "_id": - # WvA: Do we really wish to disallow this? 
- raise ValueError("You cannot directly set the '_id' field in a document.") - if key not in field_settings: - raise ValueError(f"The type for field '{key}' is not yet specified") - document[key] = coerce_type(document[key], field_settings[key].type) + if has_identifiers: + identifiers = ", ".join([name for name, field in field_settings.items() if field.identifier]) + raise ValueError(f"This index uses identifier ({identifiers}), so you cannot set the _id directly.") + action["_id"] = document["_id"] + else: + if has_identifiers: + action["_id"] = create_id(document, field_settings) + ## if no id is given, elasticsearch creates a cool unique one - id = create_id(document, field_settings) # https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html if op_type == "update": - yield {"_op_type": op_type, "_index": index, "_id": id, "doc": document, "doc_as_upsert": True} + if "_id" not in action: + raise ValueError("Update requires _id") + action["doc"] = doc + action["doc_as_upsert"] = True else: - yield {"_op_type": op_type, "_index": index, "_id": id, **document} + action.update(doc) + + yield action actions = list(es_actions(index, documents, op_type)) try: @@ -488,9 +489,6 @@ def es_actions(index, documents, op_type): e.args = e.args + (f"First error: {reason}",) raise - if return_ids: - ids = [doc["_id"] for doc in actions] - return dict(ids=ids, successes=successes, failures=failures) return dict(successes=successes, failures=failures) diff --git a/amcat4/models.py b/amcat4/models.py index c04ae04..302f0fa 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -7,19 +7,7 @@ FieldType = Literal[ - "text", - "date", - "boolean", - "keyword", - "number", - "integer", - "object", - "vector", - "geo_point", - "url", - "image_url", - "tag", - "json", + "text", "date", "boolean", "keyword", "number", "integer", "object", "vector", "geo_point", "image_url", "tag", "json" ] ElasticType = Literal[ "text", diff --git a/tests/conftest.py b/tests/conftest.py index 5a5a8dd..556ff8a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -142,14 +142,13 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, FieldType | """ res = upload_documents(index, docs, fields) refresh_index(index) - return res["ids"] TEST_DOCUMENTS = [ - {"id": 0, "cat": "a", "subcat": "x", "i": 1, "date": "2018-01-01", "text": "this is a text", "title": "title"}, - {"id": 1, "cat": "a", "subcat": "x", "i": 2, "date": "2018-02-01", "text": "a test text", "title": "title"}, + {"_id": 0, "cat": "a", "subcat": "x", "i": 1, "date": "2018-01-01", "text": "this is a text", "title": "title"}, + {"_id": 1, "cat": "a", "subcat": "x", "i": 2, "date": "2018-02-01", "text": "a test text", "title": "title"}, { - "id": 2, + "_id": 2, "cat": "a", "subcat": "y", "i": 11, @@ -158,7 +157,7 @@ def upload(index: str, docs: list[dict[str, Any]], fields: dict[str, FieldType | "title": "bla", }, { - "id": 3, + "_id": 3, "cat": "b", "subcat": "y", "i": 31, @@ -175,7 +174,6 @@ def populate_index(index): index, TEST_DOCUMENTS, fields={ - "id": CreateField(type="integer", identifier=True), "text": "text", "title": "text", "date": "date", diff --git a/tests/test_api_documents.py b/tests/test_api_documents.py index feff654..612a219 100644 --- a/tests/test_api_documents.py +++ b/tests/test_api_documents.py @@ -27,9 +27,8 @@ def test_documents(client, index, user): f"index/{index}/documents", user=user, json={ - "documents": [{"id": "id", "title": "a title", "text": "text", "date": "2020-01-01"}], + "documents": 
[{"_id": "id", "title": "a title", "text": "text", "date": "2020-01-01"}], "fields": { - "id": {"type": "keyword", "identifier": True}, "title": {"type": "text"}, "text": {"type": "text"}, "date": {"type": "date"}, diff --git a/tests/test_api_query.py b/tests/test_api_query.py index 100c539..4c9f592 100644 --- a/tests/test_api_query.py +++ b/tests/test_api_query.py @@ -28,8 +28,7 @@ def qi(**query_string): assert qi(filters={"cat": {"values": ["a"]}}) == {0, 1, 2} # Can we request specific fields? - all_fields = {"_id", "id", "cat", "subcat", "i", "date", "text", "title"} - print(q()[0]) + all_fields = {"_id", "cat", "subcat", "i", "date", "text", "title"} assert set(q()[0].keys()) == all_fields assert set(q(fields=["cat"])[0].keys()) == {"_id", "cat"} assert set(q(fields=["date", "title"])[0].keys()) == {"_id", "date", "title"} @@ -190,30 +189,32 @@ def tags(): set_role(index_docs, user, Role.WRITER) assert tags() == {} - post_json( + res = post_json( client, f"/index/{index_docs}/tags_update", user=user, - expected=204, + expected=200, json=dict(action="add", field="tag", tag="x", filters={"cat": "a"}), ) - refresh_index(index_docs) + assert res["updated"] == 3 + # should refresh before returning + # refresh_index(index_docs) assert tags() == {"0": ["x"], "1": ["x"], "2": ["x"]} - post_json( + res = post_json( client, f"/index/{index_docs}/tags_update", user=user, - expected=204, + expected=200, json=dict(action="remove", field="tag", tag="x", queries=["text"]), ) - refresh_index(index_docs) + assert res["updated"] == 2 assert tags() == {"2": ["x"]} - post_json( + res = post_json( client, f"/index/{index_docs}/tags_update", user=user, - expected=204, + expected=200, json=dict(action="add", field="tag", tag="y", ids=["1", "2"]), ) - refresh_index(index_docs) + assert res["updated"] == 2 assert tags() == {"1": ["y"], "2": ["x", "y"]} diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 8b31268..97646df 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -1,6 +1,8 @@ from datetime import datetime from re import I +import pytest + from amcat4.index import ( refresh_index, upload_documents, @@ -45,7 +47,7 @@ def test_data_coerced(index): def test_fields(index): """Can we get the fields from an index""" - create_fields(index, {"title": "text", "date": "date", "text": "text", "url": "url"}) + create_fields(index, {"title": "text", "date": "date", "text": "text", "url": "keyword"}) fields = get_fields(index) assert set(fields.keys()) == {"title", "date", "text", "url"} assert fields["title"].type == "text" @@ -96,27 +98,84 @@ def tags(): assert tags() == {"1": ["x"], "2": ["y"], "3": ["y"]} -def test_deduplication(index): +def test_upload_without_identifiers(index): doc = {"title": "titel", "text": "text", "date": datetime(2020, 1, 1)} - upload_documents(index, [doc], fields={"title": "text", "text": "text", "date": "date"}) + res = upload_documents(index, [doc], fields={"title": "text", "text": "text", "date": "date"}) + assert res["successes"] == 1 _assert_n(index, 1) - upload_documents(index, [doc]) + + # this doesnt identify duplicates + res = upload_documents(index, [doc]) + assert res["successes"] == 1 + _assert_n(index, 2) + + +def test_upload_with_explicit_ids(index): + doc = {"_id": "1", "title": "titel", "text": "text", "date": datetime(2020, 1, 1)} + res = upload_documents(index, [doc], fields={"title": "text", "text": "text", "date": "date"}) + assert res["successes"] == 1 + + # this does skip docs with same id + res = upload_documents(index, [doc]) + assert 
res["successes"] == 0 _assert_n(index, 1) -def test_identifier_deduplication(index): +def test_upload_with_identifiers(index): doc = {"url": "http://", "text": "text"} - upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True), "text": "text"}) + res = upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True), "text": "text"}) + assert res["successes"] == 1 _assert_n(index, 1) doc2 = {"url": "http://", "text": "text2"} - upload_documents(index, [doc2]) + res = upload_documents(index, [doc2]) + assert res["successes"] == 0 _assert_n(index, 1) doc3 = {"url": "http://2", "text": "text"} - upload_documents(index, [doc3]) + res = upload_documents(index, [doc3]) + assert res["successes"] == 1 _assert_n(index, 2) + # cannot upload explicit id if identifiers are used + doc4 = {"_id": "1", "url": "http://", "text": "text"} + with pytest.raises(ValueError): + upload_documents(index, [doc4]) + + +def test_invalid_adding_identifiers(index): + # identifiers can only be added if (1) the index already uses identifiers or (2) the index is still empty (no docs) + doc = {"text": "text"} + upload_documents(index, [doc], fields={"text": "text"}) + refresh_index(index) + + # adding an identifier to an existing index should fail + doc = {"url": "http://", "text": "text"} + with pytest.raises(ValueError): + upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True)}) + + +def test_valid_adding_identifiers(index): + # adding an identifier to an empty index should succeed + doc = {"text": "text"} + upload_documents(index, [doc], fields={"text": CreateField(type="text", identifier=True)}) + refresh_index(index) + _assert_n(index, 1) + + # adding an identifier to an existing index should succeed if the index already has identifiers + doc = {"url": "http://", "text": "text"} + res = upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True)}) + + # the document should have been added because its not a full duplicate (in first doc url was empty) + assert res["successes"] == 1 + + # both the identifier for the first doc and the second doc should still work, so the following docs are + # both duplicates + doc1 = {"text": "text"} + doc2 = {"url": "http://", "text": "text"} + res = upload_documents(index, [doc1, doc2]) + assert res["successes"] == 0 + def _assert_n(index, n): refresh_index(index) From 1bfe1c3051fea17dca36e61986e2291c8c7be461 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Tue, 9 Apr 2024 12:43:50 +0200 Subject: [PATCH 55/80] one extra test for good riddance --- tests/test_elastic.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 97646df..4985a89 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -156,13 +156,10 @@ def test_invalid_adding_identifiers(index): def test_valid_adding_identifiers(index): - # adding an identifier to an empty index should succeed doc = {"text": "text"} upload_documents(index, [doc], fields={"text": CreateField(type="text", identifier=True)}) - refresh_index(index) - _assert_n(index, 1) - # adding an identifier to an existing index should succeed if the index already has identifiers + # adding an additional identifier to an existing index should succeed if the index already has identifiers doc = {"url": "http://", "text": "text"} res = upload_documents(index, [doc], fields={"url": CreateField(type="keyword", identifier=True)}) @@ -176,6 +173,14 @@ 
def test_valid_adding_identifiers(index): res = upload_documents(index, [doc1, doc2]) assert res["successes"] == 0 + # the order of adding identifiers doesn't matter. a document having just the url uses only the url as identifier + doc = {"url": "http://new"} + res = upload_documents(index, [doc]) + assert res["successes"] == 1 + # second time it's a duplicate + res = upload_documents(index, [doc]) + assert res["successes"] == 0 + def _assert_n(index, n): refresh_index(index) From 80fbc5ac8426a905bd1010946c9ad472c7ee9920 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Wed, 10 Apr 2024 15:40:22 +1000 Subject: [PATCH 56/80] Add multimedia support (wip) --- amcat4/__main__.py | 14 +----- amcat4/config.py | 4 +- amcat4/multimedia.py | 99 ++++++++++++++++++++++++++++++++++++++++ setup.py | 1 + tests/conftest.py | 18 ++++++-- tests/test_multimedia.py | 29 ++++++++++++ 6 files changed, 149 insertions(+), 16 deletions(-) create mode 100644 amcat4/multimedia.py create mode 100644 tests/test_multimedia.py diff --git a/amcat4/__main__.py b/amcat4/__main__.py index 81d1b7b..eb37dc8 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -143,10 +143,6 @@ def create_env(args): env = base_env() if args.admin_email: env["amcat4_admin_email"] = args.admin_email - if args.admin_password: - env["amcat4_admin_password"] = args.admin_password - if args.no_admin_password: - env["amcat4_admin_password"] = "" with open(".env", "w") as f: for key, val in env.items(): f.write(f"{key}={val}\n") @@ -166,9 +162,6 @@ def add_admin(args): def list_users(_args): - admin_password = get_settings().admin_password - if admin_password: - print("ADMIN : admin (password set via environment AMCAT4_ADMIN_PASSWORD)") users = list_global_users() # sorted changes the output type of list_global_users? if users: for user, role in users.items(): print(f"{role.name:10}: {user}") - if not (users or admin_password): - print("(No users defined yet, set AMCAT4_ADMIN_PASSWORD in environment use add-admin to add users by email)") + if not users: + print("(No users defined yet, use add-admin to add users by email)") def config_amcat(args): @@ -267,9 +260,6 @@ def main(): p = subparsers.add_parser("create-env", help="Create the .env file with a random secret key") p.add_argument("-a", "--admin_email", help="The email address of the admin user.") - p.add_argument("-p", "--admin_password", help="The password of the built-in admin user.") - p.add_argument("-P", "--no-admin_password", action="store_true", help="Disable admin password") - p.set_defaults(func=create_env) p = subparsers.add_parser("config", help="Configure amcat4 settings in an interactive menu.") diff --git a/amcat4/config.py b/amcat4/config.py index ecb4bdc..5364df8 100644 --- a/amcat4/config.py +++ b/amcat4/config.py @@ -110,7 +110,9 @@ class Settings(BaseSettings): ), ] = None - admin_password: Annotated[str | None, Field()] = None + minio_host: Annotated[str | None, Field()] = None + minio_access_key: Annotated[str | None, Field()] = None + minio_secret_key: Annotated[str | None, Field()] = None @model_validator(mode="after") def set_ssl(self: Any) -> "Settings": diff --git a/amcat4/multimedia.py b/amcat4/multimedia.py new file mode 100644 index 0000000..086e3d9 --- /dev/null +++ b/amcat4/multimedia.py @@ -0,0 +1,99 @@ +""" +Multimedia features for AmCAT + +AmCAT can link to a minio/S3 object store to provide access to multimedia content attached to documents. +The object store needs to be configured in the server settings. 
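+ +As a minimal sketch of the intended round trip (assuming minio_host, minio_access_key and minio_secret_key +are configured in the settings; the index name, object key, and image_bytes below are made-up examples): + + from amcat4 import multimedia + multimedia.add_multimedia_object("my_index", "image.png", image_bytes) + url = multimedia.presigned_get("my_index", "image.png") # short-lived download URL 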
+""" + +import datetime +from io import BytesIO +from typing import Optional +from venv import create +from amcat4.config import get_settings +from minio import Minio, S3Error +from minio.deleteobjects import DeleteObject +from minio.datatypes import PostPolicy +import functools + + +def get_minio() -> Minio: + result = connect_minio() + if result is None: + raise ValueError("Could not connect to minio") + return result + + +@functools.lru_cache() +def connect_minio() -> Optional[Minio]: + try: + return _connect_minio() + except Exception as e: + raise Exception(f"Cannot connect to minio {get_settings().minio_host!r}: {e}") + + +def _connect_minio() -> Optional[Minio]: + settings = get_settings() + if settings.minio_host is None: + return None + if settings.minio_secret_key is None or settings.minio_access_key is None: + raise ValueError("minio_access_key or minio_secret_key not specified") + return Minio(settings.minio_host, secure=False, access_key=settings.minio_access_key, secret_key=settings.minio_secret_key) + + +def bucket_name(index: str) -> str: + return index.replace("_", "-") + + +def get_bucket(minio: Minio, index: str, create_if_needed=True): + """ + Get the bucket name for this index. If create_if_needed is True, create the bucket if it doesn't exist. + Returns the bucket name, or "" if it doesn't exist and create_if_needed is False. + """ + bucket = bucket_name(index) + if not minio.bucket_exists(bucket): + if not create_if_needed: + return "" + minio.make_bucket(bucket) + return bucket + + +def list_multimedia_objects(index: str): + minio = get_minio() + bucket = get_bucket(minio, index, create_if_needed=False) + if not bucket: + return [] + for object in minio.list_objects(bucket_name(index), recursive=True): + yield dict(key=object.object_name) + + +def delete_bucket(minio: Minio, index: str): + bucket = get_bucket(minio, index, create_if_needed=False) + if not bucket: + return + to_delete = [DeleteObject(x.object_name) for x in minio.list_objects(bucket, recursive=True) if x.object_name] + errors = list(minio.remove_objects(bucket, to_delete)) + if errors: + raise Exception(f"Error on deleting objects: {errors}") + minio.remove_bucket(bucket) + + +def add_multimedia_object(index: str, key: str, bytes: bytes): + minio = get_minio() + bucket = get_bucket(minio, index) + data = BytesIO(bytes) + minio.put_object(bucket, key, data, len(bytes)) + + +def presigned_post(index: str, key_prefix, days_valid=1): + minio = get_minio() + bucket = get_bucket(minio, index) + policy = PostPolicy(bucket, expiration=datetime.datetime.now() + datetime.timedelta(days=days_valid)) + policy.add_starts_with_condition("key", key_prefix) + url = f"http://{get_settings().minio_host}/{bucket}" + return url, minio.presigned_post_policy(policy) + + +def presigned_get(index: str, key, days_valid=1): + minio = get_minio() + bucket = get_bucket(minio, index) + return minio.presigned_get_object(bucket, key, expires=datetime.timedelta(days=days_valid)) diff --git a/setup.py b/setup.py index 3e4360a..ab895cd 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "requests", "class_doc", "mypy", + "minio", ], extras_require={"dev": ["pytest", "mypy", "flake8", "responses", "pre-commit", "types-requests"]}, entry_points={"console_scripts": ["amcat4 = amcat4.__main__:main"]}, diff --git a/tests/conftest.py b/tests/conftest.py index 556ff8a..1574702 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import responses from fastapi.testclient import TestClient -from amcat4 import api # noqa: 
E402 +from amcat4 import api, multimedia # noqa: E402 from amcat4.config import get_settings, AuthOptions from amcat4.elastic import es from amcat4.index import ( @@ -16,7 +16,7 @@ set_global_role, upload_documents, ) -from amcat4.models import CreateField, ElasticType, FieldType +from amcat4.models import CreateField, FieldType from tests.middlecat_keypair import PUBLIC_KEY UNITS = [ @@ -33,7 +33,9 @@ def mock_middlecat(): get_settings().middlecat_url = "http://localhost:5000" get_settings().host = "http://localhost:3000" - with responses.RequestsMock(assert_all_requests_are_fired=False) as resp: + minio = get_settings().minio_host + passthru = (f"http://{minio}",) if minio else () + with responses.RequestsMock(passthru_prefixes=passthru, assert_all_requests_are_fired=False) as resp: resp.get("http://localhost:5000/api/configuration", json={"public_key": PUBLIC_KEY}) yield None @@ -212,3 +214,13 @@ def index_many(): @pytest.fixture() def app(): return api.app + + +@pytest.fixture() +def minio(): + minio = multimedia.connect_minio() + if not minio: + pytest.skip("No minio connected, skipping multimedia tests") + for index in ["amcat4_unittest_index"]: + multimedia.delete_bucket(minio, index) + return minio diff --git a/tests/test_multimedia.py b/tests/test_multimedia.py new file mode 100644 index 0000000..890dc64 --- /dev/null +++ b/tests/test_multimedia.py @@ -0,0 +1,29 @@ +from io import BytesIO +import os +import requests +from amcat4 import multimedia + + +def test_upload_get_multimedia(minio, index): + assert list(multimedia.list_multimedia_objects(index)) == [] + multimedia.add_multimedia_object(index, "test", b"bytes") + assert {o["key"] for o in multimedia.list_multimedia_objects(index)} == {"test"} + + +def test_presigned_form(minio, index): + assert list(multimedia.list_multimedia_objects(index)) == [] + bytes = os.urandom(32) + key = "image.png" + url, form_data = multimedia.presigned_post(index, "") + res = requests.post( + url=url, + data={"key": key, **form_data}, + files={"file": BytesIO(bytes)}, + ) + res.raise_for_status() + assert {o["key"] for o in multimedia.list_multimedia_objects(index)} == {"image.png"} + + url = multimedia.presigned_get(index, key) + res = requests.get(url) + res.raise_for_status() + assert res.content == bytes From f5f2cb4837c24281666c1a25f235ed39347056b1 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Wed, 10 Apr 2024 21:40:50 +1000 Subject: [PATCH 57/80] Add multimedia API endpoints --- amcat4/api/__init__.py | 3 ++ amcat4/api/multimedia.py | 44 ++++++++++++++++++++++ amcat4/multimedia.py | 15 ++++---- tests/test_api_multimedia.py | 73 ++++++++++++++++++++++++++++++++++++ tests/test_multimedia.py | 4 +- 5 files changed, 130 insertions(+), 9 deletions(-) create mode 100644 amcat4/api/multimedia.py create mode 100644 tests/test_api_multimedia.py diff --git a/amcat4/api/__init__.py b/amcat4/api/__init__.py index 2cba783..65df346 100644 --- a/amcat4/api/__init__.py +++ b/amcat4/api/__init__.py @@ -9,6 +9,7 @@ from amcat4.api.info import app_info from amcat4.api.query import app_query from amcat4.api.users import app_users +from amcat4.api.multimedia import app_multimedia app = FastAPI( @@ -26,12 +27,14 @@ "and the core process of getting units and posting annotations", ), dict(name="annotator guest", description="Annotator module endpoints for unregistered guests"), + dict(name="multimedia", description="Endpoints for multimedia support"), ], ) app.include_router(app_info) app.include_router(app_users) app.include_router(app_index) 
app.include_router(app_query) +app.include_router(app_multimedia) app.add_middleware( CORSMiddleware, allow_origins=["*"], diff --git a/amcat4/api/multimedia.py b/amcat4/api/multimedia.py new file mode 100644 index 0000000..eeb5418 --- /dev/null +++ b/amcat4/api/multimedia.py @@ -0,0 +1,44 @@ +import itertools +from typing import Optional +from fastapi import APIRouter, Depends + +from amcat4 import index, multimedia +from amcat4.api.auth import authenticated_user, check_role + + +app_multimedia = APIRouter(prefix="/index/{ix}/multimedia", tags=["multimedia"]) + + +@app_multimedia.get("/presigned_get") +def presigned_get(ix: str, key: str, user: str = Depends(authenticated_user)): + check_role(user, index.Role.READER, ix) + url = multimedia.presigned_get(ix, key) + return dict(url=url) + + +@app_multimedia.get("/presigned_post") +def presigned_post(ix: str, user: str = Depends(authenticated_user)): + check_role(user, index.Role.WRITER, ix) + url, form_data = multimedia.presigned_post(ix) + return dict(url=url, form_data=form_data) + + +@app_multimedia.get("/list") +def list_multimedia( + ix: str, + n: int = 10, + prefix: Optional[str] = None, + start_after: Optional[str] = None, + recursive=False, + presigned_get=False, + user: str = Depends(authenticated_user), +): + def process(obj): + result = dict(key=obj.object_name) + if presigned_get: + result["presigned_get"] = multimedia.presigned_get(ix, obj.object_name) + return result + + check_role(user, index.Role.READER, ix) + objects = multimedia.list_multimedia_objects(ix, prefix, start_after, recursive) + return [process(obj) for obj in itertools.islice(objects, n)] diff --git a/amcat4/multimedia.py b/amcat4/multimedia.py index 086e3d9..73bc0dc 100644 --- a/amcat4/multimedia.py +++ b/amcat4/multimedia.py @@ -7,12 +7,12 @@ import datetime from io import BytesIO -from typing import Optional +from typing import Iterable, Optional from venv import create from amcat4.config import get_settings from minio import Minio, S3Error from minio.deleteobjects import DeleteObject -from minio.datatypes import PostPolicy +from minio.datatypes import PostPolicy, Object import functools @@ -57,13 +57,14 @@ def get_bucket(minio: Minio, index: str, create_if_needed=True): return bucket -def list_multimedia_objects(index: str): +def list_multimedia_objects( + index: str, prefix: Optional[str] = None, start_after: Optional[str] = None, recursive=True +) -> Iterable[Object]: minio = get_minio() bucket = get_bucket(minio, index, create_if_needed=False) if not bucket: - return [] - for object in minio.list_objects(bucket_name(index), recursive=True): - yield dict(key=object.object_name) + return + yield from minio.list_objects(bucket_name(index), prefix=prefix, start_after=start_after, recursive=recursive) def delete_bucket(minio: Minio, index: str): @@ -84,7 +85,7 @@ def add_multimedia_object(index: str, key: str, bytes: bytes): minio.put_object(bucket, key, data, len(bytes)) -def presigned_post(index: str, key_prefix, days_valid=1): +def presigned_post(index: str, key_prefix: str = "", days_valid=1): minio = get_minio() bucket = get_bucket(minio, index) policy = PostPolicy(bucket, expiration=datetime.datetime.now() + datetime.timedelta(days=days_valid)) diff --git a/tests/test_api_multimedia.py b/tests/test_api_multimedia.py new file mode 100644 index 0000000..7660a6a --- /dev/null +++ b/tests/test_api_multimedia.py @@ -0,0 +1,73 @@ +from fastapi.testclient import TestClient +import requests +from amcat4 import multimedia +from amcat4.index import set_role, Role 
+from tests.tools import post_json, build_headers, get_json, check + + +def _get_names(client: TestClient, index, user, **kargs): + res = client.get(f"index/{index}/multimedia/list", params=kargs, headers=build_headers(user)) + res.raise_for_status() + return {obj["key"] for obj in res.json()} + + +def test_authorisation(minio, client, index, user, reader): + check(client.get(f"index/{index}/multimedia/list"), 401) + check(client.get(f"index/{index}/multimedia/presigned_get", params=dict(key="")), 401) + check(client.get(f"index/{index}/multimedia/presigned_post"), 401) + + set_role(index, user, Role.METAREADER) + set_role(index, reader, Role.READER) + check(client.get(f"index/{index}/multimedia/list", headers=build_headers(user)), 401) + check(client.get(f"index/{index}/multimedia/presigned_get", params=dict(key=""), headers=build_headers(user)), 401) + check(client.get(f"index/{index}/multimedia/presigned_post", headers=build_headers(reader)), 401) + + +def test_post_get_list(minio, client, index, user): + set_role(index, user, Role.WRITER) + assert _get_names(client, index, user) == set() + post = client.get(f"index/{index}/multimedia/presigned_post", headers=build_headers(user)).json() + assert set(post.keys()) == {"url", "form_data"} + multimedia.add_multimedia_object(index, "test", b"bytes") + assert _get_names(client, index, user) == {"test"} + res = client.get(f"index/{index}/multimedia/presigned_get", headers=build_headers(user), params=dict(key="test")) + res.raise_for_status() + assert requests.get(res.json()["url"]).content == b"bytes" + + +def test_list_options(minio, client, index, reader): + set_role(index, reader, Role.READER) + multimedia.add_multimedia_object(index, "myfolder/a1", b"a1") + multimedia.add_multimedia_object(index, "myfolder/a2", b"a2") + multimedia.add_multimedia_object(index, "obj1", b"obj1") + multimedia.add_multimedia_object(index, "obj2", b"obj2") + multimedia.add_multimedia_object(index, "obj3", b"obj3") + multimedia.add_multimedia_object(index, "zzz", b"zzz") + + assert _get_names(client, index, reader) == {"obj1", "obj2", "obj3", "myfolder/", "zzz"} + assert _get_names(client, index, reader, recursive=True) == {"obj1", "obj2", "obj3", "myfolder/a1", "myfolder/a2", "zzz"} + assert _get_names(client, index, reader, prefix="obj") == {"obj1", "obj2", "obj3"} + assert _get_names(client, index, reader, prefix="myfolder/") == {"myfolder/a1", "myfolder/a2"} + assert _get_names(client, index, reader, prefix="myfolder/", presigned_get=True) == {"myfolder/a1", "myfolder/a2"} + res = client.get( + f"index/{index}/multimedia/list", params=dict(prefix="myfolder/", presigned_get=True), headers=build_headers(reader) + ) + res.raise_for_status() + urls = {o["key"]: o["presigned_get"] for o in res.json()} + assert requests.get(urls["myfolder/a1"]).content == b"a1" + + +def test_list_pagination(minio, client, index, reader): + set_role(index, reader, Role.READER) + ids = [f"obj_{i:02}" for i in range(15)] + for id in ids: + multimedia.add_multimedia_object(index, id, id.encode("utf-8")) + + # default page size is 10 + names = _get_names(client, index, reader) + assert names == set(ids[:10]) + more_names = _get_names(client, index, reader, start_after=ids[9]) + assert more_names == set(ids[10:]) + + names = _get_names(client, index, reader, n=5) + assert names == set(ids[:5]) diff --git a/tests/test_multimedia.py b/tests/test_multimedia.py index 890dc64..042b1a0 100644 --- a/tests/test_multimedia.py +++ b/tests/test_multimedia.py @@ -7,7 +7,7 @@ def 
test_upload_get_multimedia(minio, index): assert list(multimedia.list_multimedia_objects(index)) == [] multimedia.add_multimedia_object(index, "test", b"bytes") - assert {o["key"] for o in multimedia.list_multimedia_objects(index)} == {"test"} + assert {o.object_name for o in multimedia.list_multimedia_objects(index)} == {"test"} def test_presigned_form(minio, index): @@ -21,7 +21,7 @@ def test_presigned_form(minio, index): files={"file": BytesIO(bytes)}, ) res.raise_for_status() - assert {o["key"] for o in multimedia.list_multimedia_objects(index)} == {"image.png"} + assert {o.object_name for o in multimedia.list_multimedia_objects(index)} == {"image.png"} url = multimedia.presigned_get(index, key) res = requests.get(url) From b41797a30b501ca0abe995b0fb61957d9e32dd85 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Thu, 11 Apr 2024 12:50:32 +1000 Subject: [PATCH 58/80] Return more info on multimedia objects --- amcat4/api/multimedia.py | 20 ++++++++++++++++---- amcat4/multimedia.py | 9 +++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/amcat4/api/multimedia.py b/amcat4/api/multimedia.py index eeb5418..65fa73c 100644 --- a/amcat4/api/multimedia.py +++ b/amcat4/api/multimedia.py @@ -4,7 +4,7 @@ from amcat4 import index, multimedia from amcat4.api.auth import authenticated_user, check_role - +from minio.datatypes import Object app_multimedia = APIRouter(prefix="/index/{ix}/multimedia", tags=["multimedia"]) @@ -31,11 +31,23 @@ def list_multimedia( start_after: Optional[str] = None, recursive=False, presigned_get=False, + metadata=False, user: str = Depends(authenticated_user), ): - def process(obj): - result = dict(key=obj.object_name) - if presigned_get: + def process(obj: Object): + if metadata and (not obj.is_dir) and obj.object_name: + obj = multimedia.stat_multimedia_object(ix, obj.object_name) + result: dict[str, object] = dict( + key=obj.object_name, + is_dir=obj.is_dir, + last_modified=obj.last_modified, + size=obj.size, + ) + if metadata: + result["metadata"] = obj.metadata + result["content_type"] = obj.content_type + + if presigned_get and not obj.is_dir: result["presigned_get"] = multimedia.presigned_get(ix, obj.object_name) return result diff --git a/amcat4/multimedia.py b/amcat4/multimedia.py index 73bc0dc..92060bd 100644 --- a/amcat4/multimedia.py +++ b/amcat4/multimedia.py @@ -7,6 +7,7 @@ import datetime from io import BytesIO +from multiprocessing import Value from typing import Iterable, Optional from venv import create from amcat4.config import get_settings from minio import Minio, S3Error from minio.deleteobjects import DeleteObject -from minio.datatypes import PostPolicy +from minio.datatypes import PostPolicy, Object import functools @@ -67,6 +68,14 @@ def list_multimedia_objects( yield from minio.list_objects(bucket_name(index), prefix=prefix, start_after=start_after, recursive=recursive) +def stat_multimedia_object(index: str, key: str) -> Object: + minio = get_minio() + bucket = get_bucket(minio, index, create_if_needed=False) + if not bucket: + raise ValueError(f"Bucket for {index} does not exist") + return minio.stat_object(bucket, key) + + def delete_bucket(minio: Minio, index: str): bucket = get_bucket(minio, index, create_if_needed=False) if not bucket: return From b2878c22c4ee47cc7e932a6646379afcc70b15a6 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Thu, 11 Apr 2024 15:27:33 +0200 Subject: [PATCH 59/80] made guest role at most writer (not admin) --- amcat4/api/index.py | 6 +++--- amcat4/index.py | 14 +++++++++++--- amcat4/models.py | 14 +++++++++++++- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index 3c54a84..a79867d 100644 --- 
a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -84,7 +84,7 @@ class ChangeIndex(BaseModel): name: str | None = None description: str | None = None - guest_role: Literal["ADMIN", "WRITER", "READER", "METAREADER", "NONE"] | None = None + guest_role: Literal["WRITER", "READER", "METAREADER", "NONE"] | None = None archive: bool | None = None @@ -98,7 +98,7 @@ def modify_index(ix: str, data: ChangeIndex, user: str = Depends(authenticated_u User needs admin rights on the index """ check_role(user, index.Role.ADMIN, ix) - guest_role = index.Role[data.guest_role] if data.guest_role is not None else None + guest_role = index.GuestRole[data.guest_role] if data.guest_role is not None else None archived = None if data.archive is not None: d = index.get_index(ix) @@ -127,7 +127,7 @@ def view_index(ix: str, user: str = Depends(authenticated_user)): role = check_role(user, index.Role.METAREADER, ix, required_global_role=index.Role.WRITER) d = index.get_index(ix)._asdict() d["user_role"] = role.name - d["guest_role"] = index.Role(d.get("guest_role", 0)).name + d["guest_role"] = index.GuestRole(d.get("guest_role", 0)).name d["description"] = d.get("description", "") or "" d["name"] = d.get("name", "") or "" return d diff --git a/amcat4/index.py b/amcat4/index.py index 886effb..37af6a3 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -33,6 +33,7 @@ """ import collections +from curses import meta from dataclasses import field from enum import IntEnum import functools @@ -65,6 +66,13 @@ class Role(IntEnum): ADMIN = 40 +class GuestRole(IntEnum): + NONE = 0 + METAREADER = 10 + READER = 20 + WRITER = 30 + + ADMIN_USER = "_admin" GUEST_USER = "_guest" GLOBAL_ROLES = "_global" @@ -260,18 +268,18 @@ def set_global_role(email: str, role: Role | None): set_role(index=GLOBAL_ROLES, email=email, role=role) -def set_guest_role(index: str, guest_role: Optional[Role]): +def set_guest_role(index: str, guest_role: Optional[GuestRole]): """ Set the guest role for this index. 
Set to None to disallow guest access """ - modify_index(index, guest_role=Role.NONE if guest_role is None else guest_role) + modify_index(index, guest_role=GuestRole.NONE if guest_role is None else guest_role) def modify_index( index: str, name: Optional[str] = None, description: Optional[str] = None, - guest_role: Optional[Role] = None, + guest_role: Optional[GuestRole] = None, archived: Optional[str] = None, ): doc = dict( diff --git a/amcat4/models.py b/amcat4/models.py index 302f0fa..9129f88 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -7,7 +7,19 @@ FieldType = Literal[ - "text", "date", "boolean", "keyword", "number", "integer", "object", "vector", "geo_point", "image_url", "tag", "json" + "text", + "date", + "boolean", + "keyword", + "number", + "integer", + "object", + "vector", + "geo_point", + "image_url", + "tag", + "json", + "url", ] ElasticType = Literal[ "text", From f706e6b2ec1bd3868e63f80e3e61f9f9bf44083e Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 12 Apr 2024 18:02:03 +0200 Subject: [PATCH 60/80] some quick validation in list multimedia --- amcat4/api/multimedia.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/amcat4/api/multimedia.py b/amcat4/api/multimedia.py index 65fa73c..4d44085 100644 --- a/amcat4/api/multimedia.py +++ b/amcat4/api/multimedia.py @@ -34,6 +34,10 @@ def list_multimedia( metadata=False, user: str = Depends(authenticated_user), ): + recursive = str(recursive).lower() == "true" + metadata = str(metadata).lower() == "true" + presigned_get = str(presigned_get).lower() == "true" + def process(obj: Object): if metadata and (not obj.is_dir) and obj.object_name: obj = multimedia.stat_multimedia_object(ix, obj.object_name) @@ -47,7 +51,9 @@ def process(obj: Object): result["metadata"] = obj.metadata result["content_type"] = obj.content_type - if presigned_get and not obj.is_dir: + if presigned_get is True and not obj.is_dir: + if n > 10: + raise ValueError("Cannot provide presigned_get for more than 10 objects") result["presigned_get"] = multimedia.presigned_get(ix, obj.object_name) return result From c0c8272f9b43eabcb0e5d41267559607ce044d8e Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Sun, 14 Apr 2024 18:18:25 +0200 Subject: [PATCH 61/80] include content type in presigned_get (for rendering content) --- amcat4/api/multimedia.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/amcat4/api/multimedia.py b/amcat4/api/multimedia.py index 4d44085..4a48c58 100644 --- a/amcat4/api/multimedia.py +++ b/amcat4/api/multimedia.py @@ -13,7 +13,8 @@ def presigned_get(ix: str, key: str, user: str = Depends(authenticated_user)): check_role(user, index.Role.READER, ix) url = multimedia.presigned_get(ix, key) - return dict(url=url) + obj = multimedia.stat_multimedia_object(ix, key) + return dict(url=url, content_type=obj.content_type, size=obj.size) From 79332b13e182f53036d63563c292b28c813b2cdd Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Mon, 15 Apr 2024 18:17:58 +0200 Subject: [PATCH 62/80] added video and image fields --- amcat4/api/index.py | 2 +- amcat4/fields.py | 3 ++- amcat4/models.py | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index a79867d..b37a552 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -15,7 +15,7 @@ from amcat4.index import refresh_system_index, remove_role, set_role from amcat4.fields import field_values, field_stats -from amcat4.models import 
CreateField, ElasticType, FieldType, UpdateField +from amcat4.models import CreateField, FieldType, UpdateField app_index = APIRouter(prefix="/index", tags=["index"]) diff --git a/amcat4/fields.py b/amcat4/fields.py index f9060b5..f9565d2 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -83,7 +83,8 @@ "vector": ["dense_vector"], "geo_point": ["geo_point"], "tag": ["keyword", "wildcard"], - "image_url": ["wildcard", "keyword", "constant_keyword", "text"], + "image": ["wildcard", "keyword", "constant_keyword", "text"], + "video": ["wildcard", "keyword", "constant_keyword", "text"], "url": ["wildcard", "keyword", "constant_keyword", "text"], "json": ["text"], } diff --git a/amcat4/models.py b/amcat4/models.py index 9129f88..0d800bc 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -16,7 +16,8 @@ "object", "vector", "geo_point", - "image_url", + "image", + "video", "tag", "json", "url", From 4dea6a059c8426c8cdabb117c0fcf89e8bf24ac9 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 16 Apr 2024 14:35:41 +1000 Subject: [PATCH 63/80] First commit preprocessing --- amcat4/fields.py | 5 +- amcat4/models.py | 1 + amcat4/preprocessing/instruction.py | 50 ++++++++++++++++ amcat4/preprocessing/processor.py | 38 ++++++++++++ amcat4/preprocessing/task.py | 92 +++++++++++++++++++++++++++++ tests/test_preprocessing.py | 58 ++++++++++++++++++ 6 files changed, 240 insertions(+), 4 deletions(-) create mode 100644 amcat4/preprocessing/instruction.py create mode 100644 amcat4/preprocessing/processor.py create mode 100644 amcat4/preprocessing/task.py create mode 100644 tests/test_preprocessing.py diff --git a/amcat4/fields.py b/amcat4/fields.py index f9565d2..8553a99 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -15,17 +15,13 @@ """ import datetime -from hmac import new import json -from tabnanny import check from typing import Any, Iterator, Mapping, get_args, cast from elasticsearch import NotFoundError -from httpx import get # from amcat4.api.common import py2dict -from amcat4 import elastic from amcat4.config import get_settings from amcat4.elastic import es from amcat4.models import FieldType, CreateField, ElasticType, Field, UpdateField, FieldMetareaderAccess @@ -87,6 +83,7 @@ "video": ["wildcard", "keyword", "constant_keyword", "text"], "url": ["wildcard", "keyword", "constant_keyword", "text"], "json": ["text"], + "preprocess": ["object"], } diff --git a/amcat4/models.py b/amcat4/models.py index 0d800bc..b8e9ebd 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -21,6 +21,7 @@ "tag", "json", "url", + "preprocess", ] ElasticType = Literal[ "text", diff --git a/amcat4/preprocessing/instruction.py b/amcat4/preprocessing/instruction.py new file mode 100644 index 0000000..03b6331 --- /dev/null +++ b/amcat4/preprocessing/instruction.py @@ -0,0 +1,50 @@ +import copy +import functools +from typing import Any, Dict, Iterable, List, Optional, Tuple +from pydantic import BaseModel +import requests + +from amcat4.preprocessing.task import get_task + + +class PreprocessingArgument(BaseModel): + name: str + field: Optional[str] = None + value: Optional[str | int | bool | float | List[str] | List[int] | List[float]] = None + + +class PreprocessingOutput(BaseModel): + name: str + field: str + + +class PreprocessingInstruction(BaseModel): + field: str + task: str + endpoint: str + arguments: List[PreprocessingArgument] + outputs: List[PreprocessingOutput] + + def build_request(self, doc): + # TODO: validate that instruction is valid for task! 
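+ # Illustration (hypothetical values): for the zero-shot task defined in task.py, a doc + # {"text": "some text"} with arguments input<-text and candidate_labels=["politics", "sports"] + # should become the JSON body {"inputs": "some text", "parameters": {"candidate_labels": ["politics", "sports"]}}; + # each parameter's jsonpath (e.g. $.inputs) is used below to write its value into the request template. 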
+ task = get_task(self.task) + if task.request.body != "json": + raise NotImplementedError() + if not task.request.template: + raise ValueError(f"Task {task.name} has json body but not template") + body = copy.deepcopy(task.request.template) + for argument in self.arguments: + param = task.get_parameter(argument.name) + if param.use_field == "yes": + value = doc.get(argument.field) + else: + value = argument.value + param.parsed.update(body, value) + + return requests.Request("POST", self.endpoint, json=body) + + def parse_output(self, output) -> Iterable[Tuple[str, Any]]: + task = get_task(self.task) + for arg in self.outputs: + o = task.get_output(arg.name) + yield arg.field, o.parsed.find(output)[0].value diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py new file mode 100644 index 0000000..39c8aba --- /dev/null +++ b/amcat4/preprocessing/processor.py @@ -0,0 +1,38 @@ +import asyncio +from time import sleep +from requests import Session +from amcat4.elastic import es +from amcat4.index import update_document +from amcat4.preprocessing.instruction import PreprocessingInstruction +from amcat4.preprocessing.task import get_task + + +async def run_preprocessors(): + pass + + +async def run_preprocessor(index: str, instruction: PreprocessingInstruction): + # Q: is it better to repeat a simple "get n todo docs", or to iteratively scroll past all todo items? + while True: + docs = list(get_todo(index, instruction, size=10)) + for doc in docs: + process_doc(index, instruction, doc) + if len(docs) < 10: + # There were not enough todo items, so let's sleep + await asyncio.sleep(30) + + +def get_todo(index: str, instruction: PreprocessingInstruction, size=10): + fields = [arg.field for arg in instruction.arguments if arg.field] + q = dict(bool=dict(must_not=dict(exists=dict(field=instruction.field)))) + for doc in es().search(index=index, size=size, source_includes=fields, query=q)["hits"]["hits"]: + yield {"_id": doc["_id"], **doc["_source"]} + + +def process_doc(index: str, instruction: PreprocessingInstruction, doc: dict): + req = instruction.build_request(doc) + response = Session().send(req.prepare()) + response.raise_for_status() + result = dict(instruction.parse_output(response.json())) + result[instruction.field] = dict(status="done") + update_document(index, doc["_id"], result) diff --git a/amcat4/preprocessing/task.py b/amcat4/preprocessing/task.py new file mode 100644 index 0000000..0234951 --- /dev/null +++ b/amcat4/preprocessing/task.py @@ -0,0 +1,92 @@ +import functools +from multiprocessing import Value +from typing import Any, Dict, List, Literal, Optional +from pydantic import BaseModel +import jsonpath_ng + +""" +https://huggingface.co/docs/api-inference/detailed_parameters + +""" + + +class PreprocessingRequest(BaseModel): + body: Literal["json", "binary"] + template: Optional[dict] + + +class PreprocessingOutput(BaseModel): + name: str + type: str = "string" + path: str + + @functools.cached_property + def parsed(self) -> jsonpath_ng.JSONPath: + return jsonpath_ng.parse(self.path) + + +class PreprocessingParameter(PreprocessingOutput): + use_field: Literal["yes", "no"] = "no" + default: Optional[bool | str | int | float] = None + placeholder: Optional[str] = None + + +class PreprocessingEndpoint(BaseModel): + placeholder: str + domain: List[str] + + +class PreprocessingTask(BaseModel): + """Form for query metadata.""" + + name: str + endpoint: PreprocessingEndpoint + parameters: List[PreprocessingParameter] + outputs: List[PreprocessingOutput] + 
request: PreprocessingRequest + + def get_parameter(self, name) -> PreprocessingParameter: + # TODO should probably cache this + for param in self.parameters: + if param.name == name: + return param + raise ValueError(f"Parameter {name} not defined on task {self.name}") + + def get_output(self, name) -> PreprocessingOutput: + # TODO should probably cache this + for output in self.outputs: + if output.name == name: + return output + raise ValueError(f"Output {name} not defined on task {self.name}") + + +TASKS: List[PreprocessingTask] = [ + PreprocessingTask( + # https://huggingface.co/docs/api-inference/detailed_parameters#zero-shot-classification-task + name="HuggingFace Zero-Shot", + endpoint=PreprocessingEndpoint( + placeholder="https://api-inference.huggingface.co/models/facebook/bart-large-mnli", + domain=["huggingface.co", "huggingfacecloud.com"], + ), + parameters=[ + PreprocessingParameter(name="input", type="string", use_field="yes", path="$.inputs"), + PreprocessingParameter( + name="candidate_labels", + type="string[]", + use_field="no", + placeholder="politics, sports", + path="$.parameters.candidate_labels", + ), + ], + outputs=[PreprocessingOutput(name="label", path="$.labels[0]")], + request=PreprocessingRequest(body="json", template={"inputs": "", "parameters": {"candidate_labels": ""}}), + ) +] + + +@functools.cache +def get_task(name): + for task in TASKS: + if task.name == name: + return task + raise ValueError(f"Task {name} not defined") diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py new file mode 100644 index 0000000..a2b0888 --- /dev/null +++ b/tests/test_preprocessing.py @@ -0,0 +1,58 @@ +import json + +import requests +from amcat4.fields import create_fields +from amcat4.index import get_document, refresh_index +from amcat4.preprocessing.instruction import PreprocessingInstruction +import responses + +from amcat4.preprocessing.processor import get_todo, process_doc +from tests.conftest import TEST_DOCUMENTS + +INSTRUCTION = dict( + field="preprocess_label", + task="HuggingFace Zero-Shot", + endpoint="https://api-inference.huggingface.co/models/facebook/bart-large-mnli", + arguments=[{"name": "input", "field": "text"}, {"name": "candidate_labels", "value": ["politics", "sports"]}], + outputs=[{"name": "label", "field": "class_label"}], +) + + +def test_build_request(): + + i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) + doc = dict(text="Sample text") + req = i.build_request(doc) + assert req.url == INSTRUCTION["endpoint"] + assert req.json == dict(inputs=doc["text"], parameters=dict(candidate_labels=["politics", "sports"])) + + output = {"labels": ["politics", "sports"], "scores": [0.9, 0.1]} + with 
responses.RequestsMock(assert_all_requests_are_fired=True) as resp: + resp.post(i.endpoint, json=output) + process_doc(index_docs, i, todo) + doc = get_document(index_docs, todo["_id"]) + assert doc[i.field] == {"status": "done"} + assert doc["class_label"] == "politics" + + refresh_index(index_docs) + + todos = list(get_todo(index_docs, i)) + assert {doc["_id"] for doc in todos} == {str(doc["_id"]) for doc in TEST_DOCUMENTS} - {todo["_id"]} From c1e733bf96972297a1e8d4b594905c057c29b1c2 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 16 Apr 2024 15:07:40 +1000 Subject: [PATCH 64/80] Move preprocessing to async httpx --- amcat4/preprocessing/instruction.py | 5 +++-- amcat4/preprocessing/processor.py | 7 ++++--- tests/test_preprocessing.py | 28 +++++++++++++--------------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/amcat4/preprocessing/instruction.py b/amcat4/preprocessing/instruction.py index 03b6331..e5fb2e5 100644 --- a/amcat4/preprocessing/instruction.py +++ b/amcat4/preprocessing/instruction.py @@ -1,6 +1,7 @@ import copy import functools from typing import Any, Dict, Iterable, List, Optional, Tuple +import httpx from pydantic import BaseModel import requests @@ -25,7 +26,7 @@ class PreprocessingInstruction(BaseModel): arguments: List[PreprocessingArgument] outputs: List[PreprocessingOutput] - def build_request(self, doc): + def build_request(self, doc) -> httpx.Request: # TODO: validate that instruction is valid for task! task = get_task(self.task) if task.request.body != "json": @@ -41,7 +42,7 @@ def build_request(self, doc): value = argument.value param.parsed.update(body, value) - return requests.Request("POST", self.endpoint, json=body) + return httpx.Request("POST", self.endpoint, json=body) def parse_output(self, output) -> Iterable[Tuple[str, Any]]: task = get_task(self.task) diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py index 39c8aba..8efd2ab 100644 --- a/amcat4/preprocessing/processor.py +++ b/amcat4/preprocessing/processor.py @@ -1,5 +1,6 @@ import asyncio from time import sleep +from httpx import AsyncClient from requests import Session from amcat4.elastic import es from amcat4.index import update_document @@ -16,7 +17,7 @@ async def run_preprocessor(index: str, instruction: PreprocessingInstruction): while True: docs = list(get_todo(index, instruction, size=10)) for doc in docs: - process_doc(index, instruction, doc) + await process_doc(index, instruction, doc) if len(docs) < 10: # There were not enough todo items, so let's sleep await asyncio.sleep(30) @@ -29,9 +30,9 @@ def get_todo(index: str, instruction: PreprocessingInstruction, size=10): yield {"_id": doc["_id"], **doc["_source"]} -def process_doc(index: str, instruction: PreprocessingInstruction, doc: dict): +async def process_doc(index: str, instruction: PreprocessingInstruction, doc: dict): req = instruction.build_request(doc) - response = Session().send(req.prepare()) + response = await AsyncClient().send(req) response.raise_for_status() result = dict(instruction.parse_output(response.json())) result[instruction.field] = dict(status="done") diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index a2b0888..6b748a1 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,5 +1,8 @@ +from pytest_httpx import HTTPXMock import json +import httpx +import pytest import requests from amcat4.fields import create_fields from amcat4.index import get_document, refresh_index @@ -19,23 +22,22 @@ def 
test_build_request(): - i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) doc = dict(text="Sample text") req = i.build_request(doc) assert req.url == INSTRUCTION["endpoint"] - assert req.json == dict(inputs=doc["text"], parameters=dict(candidate_labels=["politics", "sports"])) + assert json.loads(req.content) == dict(inputs=doc["text"], parameters=dict(candidate_labels=["politics", "sports"])) + +def test_parse_result(): + i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) output = {"labels": ["politics", "sports"], "scores": [0.9, 0.1]} - with responses.RequestsMock(assert_all_requests_are_fired=True) as resp: - resp.post(i.endpoint, json=output) - result = requests.Session().send(req.prepare()) - result.raise_for_status() - update = dict(i.parse_output(result.json())) - assert update == dict(class_label="politics") + update = dict(i.parse_output(output)) + assert update == dict(class_label="politics") -def test_preprocess(index_docs): +@pytest.mark.asyncio +async def test_preprocess(index_docs, httpx_mock: HTTPXMock): i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) create_fields(index_docs, {i.field: "preprocess"}) todos = list(get_todo(index_docs, i)) @@ -43,16 +45,12 @@ def test_preprocess(index_docs): assert {doc["_id"] for doc in todos} == {str(doc["_id"]) for doc in TEST_DOCUMENTS} todo = sorted(todos, key=lambda todo: todo["_id"])[0] - output = {"labels": ["politics", "sports"], "scores": [0.9, 0.1]} - - with responses.RequestsMock(assert_all_requests_are_fired=True) as resp: - resp.post(i.endpoint, json=output) - process_doc(index_docs, i, todo) + httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]}) + await process_doc(index_docs, i, todo) doc = get_document(index_docs, todo["_id"]) assert doc[i.field] == {"status": "done"} assert doc["class_label"] == "politics" refresh_index(index_docs) - todos = list(get_todo(index_docs, i)) assert {doc["_id"] for doc in todos} == {str(doc["_id"]) for doc in TEST_DOCUMENTS} - {todo["_id"]} From 3ed9acabdbb41d5d811e67123cfacc541f953748 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 16 Apr 2024 15:31:34 +1000 Subject: [PATCH 65/80] Add process_documents function --- amcat4/preprocessing/processor.py | 18 +++++++++++------- tests/test_preprocessing.py | 10 +++++++++- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py index 8efd2ab..e511086 100644 --- a/amcat4/preprocessing/processor.py +++ b/amcat4/preprocessing/processor.py @@ -3,7 +3,7 @@ from httpx import AsyncClient from requests import Session from amcat4.elastic import es -from amcat4.index import update_document +from amcat4.index import refresh_index, update_document from amcat4.preprocessing.instruction import PreprocessingInstruction from amcat4.preprocessing.task import get_task @@ -12,18 +12,22 @@ async def run_preprocessors(): pass -async def run_preprocessor(index: str, instruction: PreprocessingInstruction): +async def process_documents(index: str, instruction: PreprocessingInstruction, size=100): + """ + Process all currently to-do documents in the index for this instruction. + Returns when it runs out of documents to do + """ # Q: it it better to repeat a simple "get n todo docs", or to iteratively scroll past all todo items? 
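+        # The simple repeated query is what this loop implements: each processed document
+        # gets instruction.field filled in, so it stops matching the must_not/exists query
+        # in get_todo and every iteration fetches fresh work. A minimal way to drive it,
+        # assuming an index that contains documents and a valid instruction:
+        #
+        #   import asyncio
+        #   asyncio.run(process_documents("my_index", instruction, size=100))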
while True: - docs = list(get_todo(index, instruction, size=10)) + docs = list(get_todo(index, instruction, size=size)) for doc in docs: await process_doc(index, instruction, doc) - if len(docs) < 10: - # There were not enough todo items, so let's sleep - await asyncio.sleep(30) + if len(docs) < size: + return + refresh_index(index) -def get_todo(index: str, instruction: PreprocessingInstruction, size=10): +def get_todo(index: str, instruction: PreprocessingInstruction, size=100): fields = [arg.field for arg in instruction.arguments if arg.field] q = dict(bool=dict(must_not=dict(exists=dict(field=instruction.field)))) for doc in es().search(index=index, size=size, source_includes=fields, query=q)["hits"]["hits"]: diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 6b748a1..ddec7f2 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -9,7 +9,7 @@ from amcat4.preprocessing.instruction import PreprocessingInstruction import responses -from amcat4.preprocessing.processor import get_todo, process_doc +from amcat4.preprocessing.processor import get_todo, process_doc, process_documents, run_preprocessors from tests.conftest import TEST_DOCUMENTS INSTRUCTION = dict( @@ -54,3 +54,11 @@ async def test_preprocess(index_docs, httpx_mock: HTTPXMock): refresh_index(index_docs) todos = list(get_todo(index_docs, i)) assert {doc["_id"] for doc in todos} == {str(doc["_id"]) for doc in TEST_DOCUMENTS} - {todo["_id"]} + + # run all preprocessors in a loop + await process_documents(index_docs, i, size=2) + refresh_index(index_docs) + todos = list(get_todo(index_docs, i)) + assert len(todos) == 0 + # There should be one call per document! + assert len(httpx_mock.get_requests()) == len(TEST_DOCUMENTS) From f975f9e64fe1bf7aff47aad13e5551756b7391f6 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Tue, 16 Apr 2024 11:53:11 +0200 Subject: [PATCH 66/80] added audio --- amcat4/api/multimedia.py | 14 ++++++++++---- amcat4/fields.py | 3 ++- amcat4/models.py | 1 + 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/amcat4/api/multimedia.py b/amcat4/api/multimedia.py index 4a48c58..0ffd239 100644 --- a/amcat4/api/multimedia.py +++ b/amcat4/api/multimedia.py @@ -1,10 +1,11 @@ import itertools from typing import Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from amcat4 import index, multimedia from amcat4.api.auth import authenticated_user, check_role from minio.datatypes import Object +from minio.error import S3Error app_multimedia = APIRouter(prefix="/index/{ix}/multimedia", tags=["multimedia"]) @@ -12,9 +13,14 @@ @app_multimedia.get("/presigned_get") def presigned_get(ix: str, key: str, user: str = Depends(authenticated_user)): check_role(user, index.Role.READER, ix) - url = multimedia.presigned_get(ix, key) - obj = multimedia.stat_multimedia_object(ix, key) - return dict(url=url, content_type=(obj.content_type,), size=obj.size) + try: + url = multimedia.presigned_get(ix, key) + obj = multimedia.stat_multimedia_object(ix, key) + return dict(url=url, content_type=(obj.content_type,), size=obj.size) + except S3Error as e: + if e.code == "NoSuchKey": + raise HTTPException(status_code=404, detail=f"multimedia file {key} not found") + raise HTTPException(status_code=404, detail=e.message) @app_multimedia.get("/presigned_post") diff --git a/amcat4/fields.py b/amcat4/fields.py index 8553a99..61f2ca1 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -81,6 +81,7 @@ "tag": ["keyword", "wildcard"], 
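     # image/video/audio values are multimedia object keys or URLs, hence the string-like elastic types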
"image": ["wildcard", "keyword", "constant_keyword", "text"], "video": ["wildcard", "keyword", "constant_keyword", "text"], + "audio": ["wildcard", "keyword", "constant_keyword", "text"], "url": ["wildcard", "keyword", "constant_keyword", "text"], "json": ["text"], "preprocess": ["object"], @@ -132,7 +133,7 @@ def coerce_type(value: Any, type: FieldType): """ if type == "date" and isinstance(value, datetime.date): return value.isoformat() - if type in ["text", "tag", "image_url", "date"]: + if type in ["text", "tag", "image", "video", "audio", "date"]: return str(value) if type in ["boolean"]: return bool(value) diff --git a/amcat4/models.py b/amcat4/models.py index b8e9ebd..6eb8d87 100644 --- a/amcat4/models.py +++ b/amcat4/models.py @@ -18,6 +18,7 @@ "geo_point", "image", "video", + "audio", "tag", "json", "url", From b3122907b5db5bc885899b13e9cae9267aa1cb4f Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 16 Apr 2024 22:18:19 +1000 Subject: [PATCH 67/80] Add preprocessing API endpoints --- amcat4/api/__init__.py | 3 ++ amcat4/api/preprocessing.py | 26 +++++++++++ amcat4/preprocessing/instruction.py | 71 ++++++++--------------------- amcat4/preprocessing/models.py | 50 ++++++++++++++++++++ amcat4/preprocessing/processor.py | 42 ++++++++++++++--- amcat4/preprocessing/task.py | 4 ++ tests/test_api_index.py | 4 +- tests/test_api_preprocessing.py | 47 +++++++++++++++++++ tests/test_preprocessing.py | 21 +++++++-- 9 files changed, 203 insertions(+), 65 deletions(-) create mode 100644 amcat4/api/preprocessing.py create mode 100644 amcat4/preprocessing/models.py create mode 100644 tests/test_api_preprocessing.py diff --git a/amcat4/api/__init__.py b/amcat4/api/__init__.py index 65df346..920d7ec 100644 --- a/amcat4/api/__init__.py +++ b/amcat4/api/__init__.py @@ -10,6 +10,7 @@ from amcat4.api.query import app_query from amcat4.api.users import app_users from amcat4.api.multimedia import app_multimedia +from amcat4.api.preprocessing import app_preprocessing app = FastAPI( @@ -28,6 +29,7 @@ ), dict(name="annotator guest", description="Annotator module endpoints for unregistered guests"), dict(name="multimedia", description="Endpoints for multimedia support"), + dict(name="preprocessing", description="Endpoints for preprocessing support"), ], ) app.include_router(app_info) @@ -35,6 +37,7 @@ app.include_router(app_index) app.include_router(app_query) app.include_router(app_multimedia) +app.include_router(app_preprocessing) app.add_middleware( CORSMiddleware, allow_origins=["*"], diff --git a/amcat4/api/preprocessing.py b/amcat4/api/preprocessing.py new file mode 100644 index 0000000..d4d8ce7 --- /dev/null +++ b/amcat4/api/preprocessing.py @@ -0,0 +1,26 @@ +from fastapi import APIRouter, Depends, Response, status + +from amcat4 import index +from amcat4.api.auth import authenticated_user, check_role +from amcat4.preprocessing.instruction import PreprocessingInstruction, get_instructions, add_instruction +from amcat4.preprocessing.task import get_tasks + + +app_preprocessing = APIRouter(tags=["preprocessing"]) + + +@app_preprocessing.get("/preprocessing_tasks") +def list_tasks(): + return [t.model_dump() for t in get_tasks()] + + +@app_preprocessing.get("/index/{ix}/preprocessing") +def list_instructions(ix: str, user: str = Depends(authenticated_user)): + check_role(user, index.Role.READER, ix) + return get_instructions(ix) + + +@app_preprocessing.post("/index/{ix}/preprocessing", status_code=status.HTTP_204_NO_CONTENT, response_class=Response) +async def post_instruction(ix: str, 
instruction: PreprocessingInstruction, user: str = Depends(authenticated_user)): + check_role(user, index.Role.WRITER, ix) + add_instruction(ix, instruction) diff --git a/amcat4/preprocessing/instruction.py b/amcat4/preprocessing/instruction.py index e5fb2e5..0d45407 100644 --- a/amcat4/preprocessing/instruction.py +++ b/amcat4/preprocessing/instruction.py @@ -1,51 +1,20 @@ -import copy -import functools -from typing import Any, Dict, Iterable, List, Optional, Tuple -import httpx -from pydantic import BaseModel -import requests - -from amcat4.preprocessing.task import get_task - - -class PreprocessingArgument(BaseModel): - name: str - field: Optional[str] = None - value: Optional[str | int | bool | float | List[str] | List[int] | List[float]] = None - - -class PreprocessingOutput(BaseModel): - name: str - field: str - - -class PreprocessingInstruction(BaseModel): - field: str - task: str - endpoint: str - arguments: List[PreprocessingArgument] - outputs: List[PreprocessingOutput] - - def build_request(self, doc) -> httpx.Request: - # TODO: validate that instruction is valid for task! - task = get_task(self.task) - if task.request.body != "json": - raise NotImplementedError() - if not task.request.template: - raise ValueError(f"Task {task.name} has json body but not template") - body = copy.deepcopy(task.request.template) - for argument in self.arguments: - param = task.get_parameter(argument.name) - if param.use_field == "yes": - value = doc.get(argument.field) - else: - value = argument.value - param.parsed.update(body, value) - - return httpx.Request("POST", self.endpoint, json=body) - - def parse_output(self, output) -> Iterable[Tuple[str, Any]]: - task = get_task(self.task) - for arg in self.outputs: - o = task.get_output(arg.name) - yield arg.field, o.parsed.find(output)[0].value +from amcat4.config import get_settings +from amcat4.elastic import es +from amcat4.fields import create_fields, get_fields +from amcat4.preprocessing.models import PreprocessingInstruction +from amcat4.preprocessing.processor import get_manager + + +def get_instructions(index: str): + res = es().get(index=get_settings().system_index, id=index, source="preprocessing") + return res["_source"].get("preprocessing", []) + + +def add_instruction(index: str, instruction: PreprocessingInstruction): + if instruction.field in get_fields(index): + raise ValueError("Field {instruction.field} already exists in index {index}") + current = get_instructions(index) + current.append(instruction.model_dump()) + create_fields(index, {instruction.field: "preprocess"}) + es().update(index=get_settings().system_index, id=index, doc=dict(preprocessing=current)) + get_manager().add_preprocessor(index, instruction) diff --git a/amcat4/preprocessing/models.py b/amcat4/preprocessing/models.py new file mode 100644 index 0000000..eaa08ea --- /dev/null +++ b/amcat4/preprocessing/models.py @@ -0,0 +1,50 @@ +import copy +from typing import Any, Iterable, List, Optional, Tuple + +import httpx +from pydantic import BaseModel + +from amcat4.preprocessing.task import get_task + + +class PreprocessingArgument(BaseModel): + name: str + field: Optional[str] = None + value: Optional[str | int | bool | float | List[str] | List[int] | List[float]] = None + + +class PreprocessingOutput(BaseModel): + name: str + field: str + + +class PreprocessingInstruction(BaseModel): + field: str + task: str + endpoint: str + arguments: List[PreprocessingArgument] + outputs: List[PreprocessingOutput] + + def build_request(self, doc) -> httpx.Request: + # TODO: 
validate that instruction is valid for task! + task = get_task(self.task) + if task.request.body != "json": + raise NotImplementedError() + if not task.request.template: + raise ValueError(f"Task {task.name} has json body but not template") + body = copy.deepcopy(task.request.template) + for argument in self.arguments: + param = task.get_parameter(argument.name) + if param.use_field == "yes": + value = doc.get(argument.field) + else: + value = argument.value + param.parsed.update(body, value) + + return httpx.Request("POST", self.endpoint, json=body) + + def parse_output(self, output) -> Iterable[Tuple[str, Any]]: + task = get_task(self.task) + for arg in self.outputs: + o = task.get_output(arg.name) + yield arg.field, o.parsed.find(output)[0].value diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py index e511086..309136e 100644 --- a/amcat4/preprocessing/processor.py +++ b/amcat4/preprocessing/processor.py @@ -1,15 +1,42 @@ import asyncio -from time import sleep +from functools import cache +import logging +from typing import Dict, Tuple +from anyio import create_task_group from httpx import AsyncClient -from requests import Session from amcat4.elastic import es from amcat4.index import refresh_index, update_document -from amcat4.preprocessing.instruction import PreprocessingInstruction -from amcat4.preprocessing.task import get_task +from amcat4.preprocessing.models import PreprocessingInstruction -async def run_preprocessors(): - pass +class PreprocessorManager: + SINGLETON = None + + def __init__(self): + self.preprocessors: Dict[Tuple[str, str], asyncio.Task] = {} + self.task_group = create_task_group() + + def add_preprocessor(self, index: str, instruction: PreprocessingInstruction): + self.preprocessors[index, instruction.field] = asyncio.create_task(run_processor_loop(index, instruction)) + + def stop_preprocessor(self, index: str, field: str): + self.preprocessors[index, field].cancel() + + def stop(self): + for task in self.preprocessors.values(): + task.cancel() + + +@cache +def get_manager(): + return PreprocessorManager() + + +async def run_processor_loop(index, instruction: PreprocessingInstruction): + while True: + logging.info(f"Preprocessing documents for {index}.{instruction.field}") + await process_documents(index, instruction) + await asyncio.sleep(1) async def process_documents(index: str, instruction: PreprocessingInstruction, size=100): @@ -27,7 +54,7 @@ async def process_documents(index: str, instruction: PreprocessingInstruction, s refresh_index(index) -def get_todo(index: str, instruction: PreprocessingInstruction, size=100): +def get_todo(index: str, instruction, size=100): fields = [arg.field for arg in instruction.arguments if arg.field] q = dict(bool=dict(must_not=dict(exists=dict(field=instruction.field)))) for doc in es().search(index=index, size=size, source_includes=fields, query=q)["hits"]["hits"]: @@ -35,6 +62,7 @@ def get_todo(index: str, instruction: PreprocessingInstruction, size=100): async def process_doc(index: str, instruction: PreprocessingInstruction, doc: dict): + # TODO catch errors and add to status field, rather than raising req = instruction.build_request(doc) response = await AsyncClient().send(req) response.raise_for_status() diff --git a/amcat4/preprocessing/task.py b/amcat4/preprocessing/task.py index 0234951..772d12a 100644 --- a/amcat4/preprocessing/task.py +++ b/amcat4/preprocessing/task.py @@ -90,3 +90,7 @@ def get_task(name): if task.name == name: return task raise ValueError(f"Task {task} not 
defined") + + +def get_tasks(): + return TASKS diff --git a/tests/test_api_index.py b/tests/test_api_index.py index ed90496..a285f99 100644 --- a/tests/test_api_index.py +++ b/tests/test_api_index.py @@ -2,7 +2,7 @@ from amcat4 import elastic -from amcat4.index import get_guest_role, Role, set_guest_role, set_role, remove_role +from amcat4.index import GuestRole, get_guest_role, Role, set_guest_role, set_role, remove_role from amcat4.fields import update_fields from tests.tools import build_headers, post_json, get_json, check, refresh @@ -231,7 +231,7 @@ def test_name_description(client, index, index_name, user, admin): assert (get_json(client, f"/index/{index}", user=user) or {})["name"] == "test" set_role(index, user, None) check(client.get(f"/index/{index}", headers=build_headers(user)), 401) - set_guest_role(index, Role.METAREADER) + set_guest_role(index, GuestRole.METAREADER) assert (get_json(client, f"/index/{index}", user=user) or {})["name"] == "test" check( diff --git a/tests/test_api_preprocessing.py b/tests/test_api_preprocessing.py new file mode 100644 index 0000000..73aa25a --- /dev/null +++ b/tests/test_api_preprocessing.py @@ -0,0 +1,47 @@ +import asyncio +import pytest +from amcat4.index import Role, get_document, refresh_index, set_role +from tests.conftest import TEST_DOCUMENTS +from tests.test_preprocessing import INSTRUCTION +from tests.tools import build_headers, check + + +def test_get_tasks(client): + # UPDATE after we make a proper 'task store' + res = client.get("/preprocessing_tasks") + res.raise_for_status() + assert any(task["name"] == "HuggingFace Zero-Shot" for task in res.json()) + + +def test_auth(client, index, user): + check(client.get(f"/index/{index}/preprocessing"), 401) + check(client.post(f"/index/{index}/preprocessing", json=INSTRUCTION), 401) + set_role(index, user, Role.READER) + + check(client.get(f"/index/{index}/preprocessing", headers=build_headers(user=user)), 200) + check(client.post(f"/index/{index}/preprocessing", json=INSTRUCTION, headers=build_headers(user=user)), 401) + + +@pytest.mark.asyncio +async def test_post_get_instructions(client, user, index_docs, httpx_mock): + set_role(index_docs, user, Role.WRITER) + res = client.get(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user)) + res.raise_for_status() + assert len(res.json()) == 0 + + httpx_mock.add_response(url=INSTRUCTION["endpoint"], json={"labels": ["games", "sports"], "scores": [0.9, 0.1]}) + + res = client.post(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user), json=INSTRUCTION) + res.raise_for_status() + refresh_index(index_docs) + res = client.get(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user)) + res.raise_for_status() + assert {item["field"] for item in res.json()} == {INSTRUCTION["field"]} + + while len(httpx_mock.get_requests()) < len(TEST_DOCUMENTS): + await asyncio.sleep(0.1) + await asyncio.sleep(0.1) + assert all(get_document(index_docs, doc["_id"])["class_label"] == "games" for doc in TEST_DOCUMENTS) + + # Cannot re-add the same field + check(client.post(f"/index/{index_docs}/preprocessing", json=INSTRUCTION, headers=build_headers(user=user)), 400) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index ddec7f2..984909b 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,15 +1,14 @@ +import asyncio +import time from pytest_httpx import HTTPXMock import json -import httpx import pytest -import requests from amcat4.fields import create_fields from 
amcat4.index import get_document, refresh_index -from amcat4.preprocessing.instruction import PreprocessingInstruction -import responses +from amcat4.preprocessing.instruction import PreprocessingInstruction, add_instruction -from amcat4.preprocessing.processor import get_todo, process_doc, process_documents, run_preprocessors +from amcat4.preprocessing.processor import PreprocessorManager, get_todo, process_doc, process_documents from tests.conftest import TEST_DOCUMENTS INSTRUCTION = dict( @@ -62,3 +61,15 @@ async def test_preprocess(index_docs, httpx_mock: HTTPXMock): assert len(todos) == 0 # There should be one call per document! assert len(httpx_mock.get_requests()) == len(TEST_DOCUMENTS) + + +@pytest.mark.asyncio +async def test_preprocess_loop(index_docs, httpx_mock: HTTPXMock): + i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) + httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]}) + add_instruction(index_docs, i) + while len(httpx_mock.get_requests()) < len(TEST_DOCUMENTS): + await asyncio.sleep(0.1) + await asyncio.sleep(0.5) + assert len(httpx_mock.get_requests()) == len(TEST_DOCUMENTS) + assert all(get_document(index_docs, doc["_id"])["class_label"] == "politics" for doc in TEST_DOCUMENTS) From 101434838d6dcf55cdaa27de62e6d7ba0a8e0913 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Wed, 17 Apr 2024 00:56:14 +0200 Subject: [PATCH 68/80] added jsonpath dependency --- amcat4/api/preprocessing.py | 2 ++ amcat4/preprocessing/task.py | 2 +- setup.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/amcat4/api/preprocessing.py b/amcat4/api/preprocessing.py index d4d8ce7..f0958bf 100644 --- a/amcat4/api/preprocessing.py +++ b/amcat4/api/preprocessing.py @@ -11,6 +11,8 @@ @app_preprocessing.get("/preprocessing_tasks") def list_tasks(): + print(get_tasks()) + print([t.model_dump() for t in get_tasks()]) return [t.model_dump() for t in get_tasks()] diff --git a/amcat4/preprocessing/task.py b/amcat4/preprocessing/task.py index 772d12a..4654a2e 100644 --- a/amcat4/preprocessing/task.py +++ b/amcat4/preprocessing/task.py @@ -66,7 +66,7 @@ def get_output(self, name) -> PreprocessingOutput: name="HuggingFace Zero-Shot", endpoint=PreprocessingEndpoint( placeholder="https://api-inference.huggingface.co/models/facebook/bart-large-mnli", - domain=["hugginggace.co", "huggingfacecloud.com"], + domain=["huggingface.co", "huggingfacecloud.com"], ), parameters=[ PreprocessingParameter(name="input", type="string", use_field="yes", path="$.inputs"), diff --git a/setup.py b/setup.py index ab895cd..d19636e 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "class_doc", "mypy", "minio", + "jsonpath_ng", ], extras_require={"dev": ["pytest", "mypy", "flake8", "responses", "pre-commit", "types-requests"]}, entry_points={"console_scripts": ["amcat4 = amcat4.__main__:main"]}, From bffbfdceb4305c4606f521d3f0b70a6310154971 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Wed, 17 Apr 2024 14:38:18 +1000 Subject: [PATCH 69/80] Improve preprocessing status and error handling; add header parameters --- amcat4/__main__.py | 7 ++- amcat4/api/__init__.py | 13 +++++ amcat4/api/index.py | 9 ---- amcat4/api/preprocessing.py | 23 +++++++-- amcat4/fields.py | 6 ++- amcat4/index.py | 4 ++ amcat4/preprocessing/instruction.py | 19 ++++++-- amcat4/preprocessing/models.py | 14 ++++-- amcat4/preprocessing/processor.py | 74 +++++++++++++++++++++++++---- amcat4/preprocessing/task.py | 8 ++++ tests/test_api_preprocessing.py | 
12 +++-- 11 files changed, 151 insertions(+), 38 deletions(-) diff --git a/amcat4/__main__.py b/amcat4/__main__.py index eb37dc8..1ae98de 100644 --- a/amcat4/__main__.py +++ b/amcat4/__main__.py @@ -8,6 +8,7 @@ import json import logging import os +from pathlib import Path import secrets import sys from typing import Any @@ -18,6 +19,8 @@ import uvicorn from pydantic.fields import FieldInfo +from uvicorn.config import LOGGING_CONFIG + from amcat4 import index from amcat4.config import get_settings, AuthOptions, validate_settings @@ -69,7 +72,8 @@ def run(args): ) if ping(): logging.info(f"Connect to elasticsearch {get_settings().elastic_host}") - uvicorn.run("amcat4.api:app", host="0.0.0.0", reload=not args.nodebug, port=args.port) + log_config = "logging.yml" if Path("logging.yml").exists() else LOGGING_CONFIG + uvicorn.run("amcat4.api:app", host="0.0.0.0", reload=not args.nodebug, port=args.port, log_config=log_config) def val(val_or_list): @@ -283,7 +287,6 @@ def main(): logging.basicConfig(format="[%(levelname)-7s:%(name)-15s] %(message)s", level=logging.INFO) es_logger = logging.getLogger("elasticsearch") es_logger.setLevel(logging.WARNING) - args.func(args) diff --git a/amcat4/api/__init__.py b/amcat4/api/__init__.py index 920d7ec..b6f9c9f 100644 --- a/amcat4/api/__init__.py +++ b/amcat4/api/__init__.py @@ -1,5 +1,7 @@ """AmCAT4 API.""" +from contextlib import asynccontextmanager +import logging from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware @@ -11,6 +13,16 @@ from amcat4.api.users import app_users from amcat4.api.multimedia import app_multimedia from amcat4.api.preprocessing import app_preprocessing +from amcat4.preprocessing.processor import start_processors + + +@asynccontextmanager +async def lifespan(app: FastAPI): + try: + start_processors() + except: + logging.exception("Error on initializing preprocessing") + yield app = FastAPI( @@ -31,6 +43,7 @@ dict(name="multimedia", description="Endpoints for multimedia support"), dict(name="preprocessing", description="Endpoints for preprocessing support"), ], + lifespan=lifespan, ) app.include_router(app_info) app.include_router(app_users) diff --git a/amcat4/api/index.py b/amcat4/api/index.py index b37a552..dd8b34b 100644 --- a/amcat4/api/index.py +++ b/amcat4/api/index.py @@ -160,17 +160,8 @@ def archive_index( def delete_index(ix: str, user: str = Depends(authenticated_user)): """Delete the index.""" check_role(user, index.Role.ADMIN, ix) - min_archived_before_delete = 7 # days - try: - d = index.get_index(ix) - if d.archived is None or (datetime.now() - d.archived).days < min_archived_before_delete: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Can only delete an index after it has been archived for at least {min_archived_before_delete} days", - ) index.delete_index(ix) - except index.IndexDoesNotExist: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Index {ix} does not exist") diff --git a/amcat4/api/preprocessing.py b/amcat4/api/preprocessing.py index f0958bf..d35fe6e 100644 --- a/amcat4/api/preprocessing.py +++ b/amcat4/api/preprocessing.py @@ -1,18 +1,19 @@ -from fastapi import APIRouter, Depends, Response, status +import logging +from fastapi import APIRouter, Depends, HTTPException, Response, status from amcat4 import index from amcat4.api.auth import authenticated_user, check_role -from amcat4.preprocessing.instruction import PreprocessingInstruction, get_instructions, add_instruction +from 
amcat4.preprocessing.instruction import PreprocessingInstruction, get_instruction, get_instructions, add_instruction +from amcat4.preprocessing.processor import get_counts, get_manager from amcat4.preprocessing.task import get_tasks +logger = logging.getLogger("amcat4.preprocessing") app_preprocessing = APIRouter(tags=["preprocessing"]) @app_preprocessing.get("/preprocessing_tasks") def list_tasks(): - print(get_tasks()) - print([t.model_dump() for t in get_tasks()]) return [t.model_dump() for t in get_tasks()] @@ -26,3 +27,17 @@ def list_instructions(ix: str, user: str = Depends(authenticated_user)): async def post_instruction(ix: str, instruction: PreprocessingInstruction, user: str = Depends(authenticated_user)): check_role(user, index.Role.WRITER, ix) add_instruction(ix, instruction) + + +@app_preprocessing.get("/index/{ix}/preprocessing/{field}") +async def get_instruction_details(ix: str, field: str, user: str = Depends(authenticated_user)): + check_role(user, index.Role.WRITER, ix) + i = get_instruction(ix, field) + if i is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Preprocessing instruction for field {field} on index {ix} not found", + ) + state = get_manager().get_status(ix, field) + counts = get_counts(ix, field) + return dict(instruction=i, status=state, counts=counts) diff --git a/amcat4/fields.py b/amcat4/fields.py index 61f2ca1..1a5717a 100644 --- a/amcat4/fields.py +++ b/amcat4/fields.py @@ -16,7 +16,7 @@ import datetime import json -from typing import Any, Iterator, Mapping, get_args, cast +from typing import Any, Iterator, Literal, Mapping, get_args, cast from elasticsearch import NotFoundError @@ -192,8 +192,9 @@ def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): # if field does not exist, we add it to both the mapping and the system index if settings.identifier: new_identifiers = True - mapping[field] = {"type": elastic_type} + if settings.type == "preprocess": + mapping[field]["properties"] = dict(status=dict(type="keyword"), error=dict(type="text", index=False)) if settings.type in ["date"]: mapping[field]["format"] = "strict_date_optional_time" @@ -214,6 +215,7 @@ def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]): if len(mapping) > 0: # if there are new identifiers, check whether this is allowed first + print(json.dumps(mapping, indent=2)) es().indices.put_mapping(index=index, properties=mapping) es().update( index=get_settings().system_index, diff --git a/amcat4/index.py b/amcat4/index.py index 37af6a3..72332e2 100644 --- a/amcat4/index.py +++ b/amcat4/index.py @@ -225,6 +225,10 @@ def deregister_index(index: str, ignore_missing=False) -> None: raise else: refresh_index(system_index) + # Stop preprocessing loops on this index + from amcat4.preprocessing.processor import get_manager + + get_manager().stop_preprocessors(index) def _roles_from_elastic(roles: list[dict]) -> dict[str, Role]: diff --git a/amcat4/preprocessing/instruction.py b/amcat4/preprocessing/instruction.py index 0d45407..6671d71 100644 --- a/amcat4/preprocessing/instruction.py +++ b/amcat4/preprocessing/instruction.py @@ -1,3 +1,4 @@ +from typing import Iterable, Optional from amcat4.config import get_settings from amcat4.elastic import es from amcat4.fields import create_fields, get_fields @@ -5,16 +6,24 @@ from amcat4.preprocessing.processor import get_manager -def get_instructions(index: str): +def get_instructions(index: str) -> Iterable[PreprocessingInstruction]: res = 
es().get(index=get_settings().system_index, id=index, source="preprocessing") - return res["_source"].get("preprocessing", []) + for i in res["_source"].get("preprocessing", []): + yield PreprocessingInstruction.model_validate(i) + + +def get_instruction(index: str, field: str) -> Optional[PreprocessingInstruction]: + for i in get_instructions(index): + if i.field == field: + return i def add_instruction(index: str, instruction: PreprocessingInstruction): if instruction.field in get_fields(index): raise ValueError("Field {instruction.field} already exists in index {index}") - current = get_instructions(index) - current.append(instruction.model_dump()) + instructions = list(get_instructions(index)) + instructions.append(instruction) create_fields(index, {instruction.field: "preprocess"}) - es().update(index=get_settings().system_index, id=index, doc=dict(preprocessing=current)) + body = [i.model_dump() for i in instructions] + es().update(index=get_settings().system_index, id=index, doc=dict(preprocessing=body)) get_manager().add_preprocessor(index, instruction) diff --git a/amcat4/preprocessing/models.py b/amcat4/preprocessing/models.py index eaa08ea..33b791c 100644 --- a/amcat4/preprocessing/models.py +++ b/amcat4/preprocessing/models.py @@ -33,15 +33,23 @@ def build_request(self, doc) -> httpx.Request: if not task.request.template: raise ValueError(f"Task {task.name} has json body but not template") body = copy.deepcopy(task.request.template) + headers = {} for argument in self.arguments: param = task.get_parameter(argument.name) if param.use_field == "yes": value = doc.get(argument.field) else: value = argument.value - param.parsed.update(body, value) - - return httpx.Request("POST", self.endpoint, json=body) + if param.header: + if ":" in param.path: + path, prefix = param.path.split(":", 1) + prefix = f"{prefix} " + else: + path, prefix = param.path, "" + headers[path] = f"{prefix}{value}" + else: + param.parsed.update(body, value) + return httpx.Request("POST", self.endpoint, json=body, headers=headers) def parse_output(self, output) -> Iterable[Tuple[str, Any]]: task = get_task(self.task) diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py index 309136e..d45d557 100644 --- a/amcat4/preprocessing/processor.py +++ b/amcat4/preprocessing/processor.py @@ -2,19 +2,21 @@ from functools import cache import logging from typing import Dict, Tuple -from anyio import create_task_group -from httpx import AsyncClient +from elasticsearch import NotFoundError +from httpx import AsyncClient, HTTPStatusError from amcat4.elastic import es -from amcat4.index import refresh_index, update_document +from amcat4.index import list_known_indices, refresh_index, update_document from amcat4.preprocessing.models import PreprocessingInstruction +logger = logging.getLogger("amcat4.preprocessing") + class PreprocessorManager: SINGLETON = None def __init__(self): self.preprocessors: Dict[Tuple[str, str], asyncio.Task] = {} - self.task_group = create_task_group() + self.preprocessor_status: Dict[Tuple[str, str], str] = {} def add_preprocessor(self, index: str, instruction: PreprocessingInstruction): self.preprocessors[index, instruction.field] = asyncio.create_task(run_processor_loop(index, instruction)) @@ -22,21 +24,58 @@ def add_preprocessor(self, index: str, instruction: PreprocessingInstruction): def stop_preprocessor(self, index: str, field: str): self.preprocessors[index, field].cancel() + def stop_preprocessors(self, index: str): + tasks = list(self.preprocessors.items()) + 
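+        # materialize items() into a list first: the loop below deletes entries from
+        # self.preprocessors while iterating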
for (ix, field), task in tasks: + if index == ix: + task.cancel() + del self.preprocessors[ix, field] + del self.preprocessor_status[ix, field] + def stop(self): for task in self.preprocessors.values(): task.cancel() + def get_status(self, index: str, field: str): + task = self.preprocessors.get((index, field)) + if not task: + return "Unknown" + if task.cancelled(): + return "Cancelled" + if task.done(): + return "Stopped" + return self.preprocessor_status.get((index, field), "Unknown status") + @cache def get_manager(): return PreprocessorManager() +def start_processors(): + import amcat4.preprocessing.instruction + + logger.info("Starting preprocessing loops (if needed)") + manager = get_manager() + for index in list_known_indices(): + try: + instructions = list(amcat4.preprocessing.instruction.get_instructions(index.id)) + except NotFoundError: + logging.warning(f"Index {index.id} does not exist!") + continue + for instruction in instructions: + manager.add_preprocessor(index.id, instruction) + + async def run_processor_loop(index, instruction: PreprocessingInstruction): + logger.info(f"Starting preprocessing loop for {index}.{instruction.field}") while True: - logging.info(f"Preprocessing documents for {index}.{instruction.field}") + logger.info(f"Preprocessing loop woke up for {index}.{instruction.field}") + get_manager().preprocessor_status[index, instruction.field] = "Active" await process_documents(index, instruction) - await asyncio.sleep(1) + get_manager().preprocessor_status[index, instruction.field] = "Sleeping" + logger.info(f"Preprocessing loop sleeping for {index}.{instruction.field}") + await asyncio.sleep(10) async def process_documents(index: str, instruction: PreprocessingInstruction, size=100): @@ -47,6 +86,7 @@ async def process_documents(index: str, instruction: PreprocessingInstruction, s # Q: it it better to repeat a simple "get n todo docs", or to iteratively scroll past all todo items? 
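+        # With the error handling below, documents that fail get status="error" written to
+        # instruction.field, so they also drop out of the must_not/exists todo query instead
+        # of being retried forever; a restart would require clearing that field first.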
while True: docs = list(get_todo(index, instruction, size=size)) + logger.debug(f"Preprocessing for {index}.{instruction.field}: retrieved {len(docs)} docs to process") for doc in docs: await process_doc(index, instruction, doc) if len(docs) < size: @@ -54,18 +94,34 @@ async def process_documents(index: str, instruction: PreprocessingInstruction, s refresh_index(index) -def get_todo(index: str, instruction, size=100): +def get_todo(index: str, instruction: PreprocessingInstruction, size=100): fields = [arg.field for arg in instruction.arguments if arg.field] q = dict(bool=dict(must_not=dict(exists=dict(field=instruction.field)))) for doc in es().search(index=index, size=size, source_includes=fields, query=q)["hits"]["hits"]: yield {"_id": doc["_id"], **doc["_source"]} +def get_counts(index: str, field: str): + agg = dict(status=dict(terms=dict(field=f"{field}.status"))) + + res = es().search(index="test", size=0, aggs=agg) + result = dict(total=res["hits"]["total"]["value"]) + for bucket in res["aggregations"]["status"]["buckets"]: + result[bucket["key"]] = bucket["doc_count"] + return result + + async def process_doc(index: str, instruction: PreprocessingInstruction, doc: dict): # TODO catch errors and add to status field, rather than raising req = instruction.build_request(doc) - response = await AsyncClient().send(req) - response.raise_for_status() + try: + response = await AsyncClient().send(req) + response.raise_for_status() + except HTTPStatusError as e: + error = f"{e.response.status_code}: {e.response.text}" + logging.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}") + update_document(index, doc["_id"], {instruction.field: dict(status="error", error=error)}) + return result = dict(instruction.parse_output(response.json())) result[instruction.field] = dict(status="done") update_document(index, doc["_id"], result) diff --git a/amcat4/preprocessing/task.py b/amcat4/preprocessing/task.py index 4654a2e..915d505 100644 --- a/amcat4/preprocessing/task.py +++ b/amcat4/preprocessing/task.py @@ -29,6 +29,7 @@ class PreprocessingParameter(PreprocessingOutput): use_field: Literal["yes", "no"] = "no" default: Optional[bool | str | int | float] = None placeholder: Optional[str] = None + header: Optional[bool] = None class PreprocessingEndpoint(BaseModel): @@ -77,6 +78,13 @@ def get_output(self, name) -> PreprocessingOutput: placeholder="politics, sports", path="$.parameters.candidate_labels", ), + PreprocessingParameter( + name="Huggingface Token", + type="string", + use_field="no", + header=True, + path="Authorization:Bearer", + ), ], outputs=[PreprocessingOutput(name="label", path="$.labels[0]")], request=PreprocessingRequest(body="json", template={"inputs": "", "parameters": {"candidate_labels": ""}}), diff --git a/tests/test_api_preprocessing.py b/tests/test_api_preprocessing.py index 73aa25a..a05e72d 100644 --- a/tests/test_api_preprocessing.py +++ b/tests/test_api_preprocessing.py @@ -1,6 +1,8 @@ import asyncio +import json import pytest from amcat4.index import Role, get_document, refresh_index, set_role +from amcat4.preprocessing.models import PreprocessingInstruction from tests.conftest import TEST_DOCUMENTS from tests.test_preprocessing import INSTRUCTION from tests.tools import build_headers, check @@ -24,19 +26,21 @@ def test_auth(client, index, user): @pytest.mark.asyncio async def test_post_get_instructions(client, user, index_docs, httpx_mock): + i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) + set_role(index_docs, user, 
Role.WRITER) res = client.get(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user)) res.raise_for_status() assert len(res.json()) == 0 - httpx_mock.add_response(url=INSTRUCTION["endpoint"], json={"labels": ["games", "sports"], "scores": [0.9, 0.1]}) + httpx_mock.add_response(url=i.endpoint, json={"labels": ["games", "sports"], "scores": [0.9, 0.1]}) - res = client.post(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user), json=INSTRUCTION) + res = client.post(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user), json=i.model_dump()) res.raise_for_status() refresh_index(index_docs) res = client.get(f"/index/{index_docs}/preprocessing", headers=build_headers(user=user)) res.raise_for_status() - assert {item["field"] for item in res.json()} == {INSTRUCTION["field"]} + assert {item["field"] for item in res.json()} == {i.field} while len(httpx_mock.get_requests()) < len(TEST_DOCUMENTS): await asyncio.sleep(0.1) @@ -44,4 +48,4 @@ async def test_post_get_instructions(client, user, index_docs, httpx_mock): assert all(get_document(index_docs, doc["_id"])["class_label"] == "games" for doc in TEST_DOCUMENTS) # Cannot re-add the same field - check(client.post(f"/index/{index_docs}/preprocessing", json=INSTRUCTION, headers=build_headers(user=user)), 400) + check(client.post(f"/index/{index_docs}/preprocessing", json=i.model_dump(), headers=build_headers(user=user)), 400) From 020a31679a461d44210c31cdadab6f73e891e276 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Thu, 18 Apr 2024 15:18:19 +1000 Subject: [PATCH 70/80] Add image classificaiton task --- amcat4/multimedia.py | 9 +++++++ amcat4/preprocessing/models.py | 41 +++++++++++++++++++++++++------ amcat4/preprocessing/processor.py | 2 +- amcat4/preprocessing/task.py | 26 +++++++++++++++++--- 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/amcat4/multimedia.py b/amcat4/multimedia.py index 92060bd..d6fda86 100644 --- a/amcat4/multimedia.py +++ b/amcat4/multimedia.py @@ -76,6 +76,15 @@ def stat_multimedia_object(index: str, key: str) -> Object: return minio.stat_object(bucket, key) +def get_multimedia_object(index: str, key: str) -> bytes: + minio = get_minio() + bucket = get_bucket(minio, index, create_if_needed=False) + if not bucket: + raise ValueError(f"Bucket for {index} does not exist") + res = minio.get_object(bucket, key) + return res.read() + + def delete_bucket(minio: Minio, index: str): bucket = get_bucket(minio, index, create_if_needed=False) if not bucket: diff --git a/amcat4/preprocessing/models.py b/amcat4/preprocessing/models.py index 33b791c..babd79a 100644 --- a/amcat4/preprocessing/models.py +++ b/amcat4/preprocessing/models.py @@ -1,9 +1,12 @@ import copy +from multiprocessing import Value from typing import Any, Iterable, List, Optional, Tuple import httpx from pydantic import BaseModel +from amcat4 import multimedia +from amcat4.fields import get_fields from amcat4.preprocessing.task import get_task @@ -25,22 +28,32 @@ class PreprocessingInstruction(BaseModel): arguments: List[PreprocessingArgument] outputs: List[PreprocessingOutput] - def build_request(self, doc) -> httpx.Request: + def build_request(self, index, doc) -> httpx.Request: # TODO: validate that instruction is valid for task! 
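+        # Two request shapes are handled below: json tasks fill a copy of the request
+        # template via each parameter's jsonpath (e.g. "$.inputs" turns {"inputs": ""}
+        # into {"inputs": "<document text>"}), while binary tasks send a single
+        # multimedia object, such as an image fetched by its key, as the raw body.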
+ fields = get_fields(index) task = get_task(self.task) - if task.request.body != "json": + if task.request.body == "json": + if not task.request.template: + raise ValueError(f"Task {task.name} has json body but not template") + body = copy.deepcopy(task.request.template) + elif task.request.body == "binary": + body = None + else: raise NotImplementedError() - if not task.request.template: - raise ValueError(f"Task {task.name} has json body but not template") - body = copy.deepcopy(task.request.template) headers = {} for argument in self.arguments: param = task.get_parameter(argument.name) if param.use_field == "yes": + if not argument.field: + raise ValueError("Field not given for field param") value = doc.get(argument.field) + if task.request.body == "binary" and fields[argument.field].type in ["image"]: + value = multimedia.get_multimedia_object(index, value) else: value = argument.value if param.header: + if not param.path: + raise ValueError("Path required for header params") if ":" in param.path: path, prefix = param.path.split(":", 1) prefix = f"{prefix} " @@ -48,8 +61,22 @@ def build_request(self, doc) -> httpx.Request: path, prefix = param.path, "" headers[path] = f"{prefix}{value}" else: - param.parsed.update(body, value) - return httpx.Request("POST", self.endpoint, json=body, headers=headers) + if task.request.body == "json": + if not param.path: + raise ValueError("Path required for json body params") + param.parsed.update(body, value) + elif task.request.body == "binary": + if param.path: + raise ValueError("Path not allowed for binary body") + if body: + raise ValueError("Multiple values for body") + if type(value) != bytes: + raise ValueError("Binary request requires multimedia object") + body = value + if task.request.body == "json": + return httpx.Request("POST", self.endpoint, json=body, headers=headers) + else: + return httpx.Request("POST", self.endpoint, content=body, headers=headers) def parse_output(self, output) -> Iterable[Tuple[str, Any]]: task = get_task(self.task) diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py index d45d557..8a8e06d 100644 --- a/amcat4/preprocessing/processor.py +++ b/amcat4/preprocessing/processor.py @@ -113,7 +113,7 @@ def get_counts(index: str, field: str): async def process_doc(index: str, instruction: PreprocessingInstruction, doc: dict): # TODO catch errors and add to status field, rather than raising - req = instruction.build_request(doc) + req = instruction.build_request(index, doc) try: response = await AsyncClient().send(req) response.raise_for_status() diff --git a/amcat4/preprocessing/task.py b/amcat4/preprocessing/task.py index 915d505..56558d9 100644 --- a/amcat4/preprocessing/task.py +++ b/amcat4/preprocessing/task.py @@ -12,13 +12,13 @@ class PreprocessingRequest(BaseModel): body: Literal["json", "binary"] - template: Optional[dict] + template: Optional[dict] = None class PreprocessingOutput(BaseModel): name: str type: str = "string" - path: str + path: Optional[str] = None @functools.cached_property def parsed(self) -> jsonpath_ng.JSONPath: @@ -88,7 +88,27 @@ def get_output(self, name) -> PreprocessingOutput: ], outputs=[PreprocessingOutput(name="label", path="$.labels[0]")], request=PreprocessingRequest(body="json", template={"inputs": "", "parameters": {"candidate_labels": ""}}), - ) + ), + PreprocessingTask( + # https://huggingface.co/docs/api-inference/detailed_parameters#zero-shot-classification-task + name="HuggingFace Image Classification", + endpoint=PreprocessingEndpoint( + 
placeholder="https://api-inference.huggingface.co/models/google/vit-base-patch16-224", + domain=["huggingface.co", "huggingfacecloud.com"], + ), + parameters=[ + PreprocessingParameter(name="input", type="image", use_field="yes"), + PreprocessingParameter( + name="Huggingface Token", + type="string", + use_field="no", + header=True, + path="Authorization:Bearer", + ), + ], + outputs=[PreprocessingOutput(name="label", path="$[0].label")], + request=PreprocessingRequest(body="binary"), + ), ] From e065f6812622e1e64413ca5c28af86de5d3f5560 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 19 Apr 2024 01:20:01 +0200 Subject: [PATCH 71/80] added recommended output field type and secret option for parameters --- amcat4/preprocessing/instruction.py | 6 +++++- amcat4/preprocessing/models.py | 1 + amcat4/preprocessing/task.py | 18 ++++++++++++++---- setup.py | 2 +- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/amcat4/preprocessing/instruction.py b/amcat4/preprocessing/instruction.py index 6671d71..3571894 100644 --- a/amcat4/preprocessing/instruction.py +++ b/amcat4/preprocessing/instruction.py @@ -9,6 +9,9 @@ def get_instructions(index: str) -> Iterable[PreprocessingInstruction]: res = es().get(index=get_settings().system_index, id=index, source="preprocessing") for i in res["_source"].get("preprocessing", []): + for a in i.get("arguments", []): + if a.get("secret"): + a["value"] = "********" yield PreprocessingInstruction.model_validate(i) @@ -19,8 +22,9 @@ def get_instruction(index: str, field: str) -> Optional[PreprocessingInstruction def add_instruction(index: str, instruction: PreprocessingInstruction): + print(instruction) if instruction.field in get_fields(index): - raise ValueError("Field {instruction.field} already exists in index {index}") + raise ValueError(f"Field {instruction.field} already exists in index {index}") instructions = list(get_instructions(index)) instructions.append(instruction) create_fields(index, {instruction.field: "preprocess"}) diff --git a/amcat4/preprocessing/models.py b/amcat4/preprocessing/models.py index babd79a..55f25e0 100644 --- a/amcat4/preprocessing/models.py +++ b/amcat4/preprocessing/models.py @@ -14,6 +14,7 @@ class PreprocessingArgument(BaseModel): name: str field: Optional[str] = None value: Optional[str | int | bool | float | List[str] | List[int] | List[float]] = None + secret: Optional[bool] = False class PreprocessingOutput(BaseModel): diff --git a/amcat4/preprocessing/task.py b/amcat4/preprocessing/task.py index 56558d9..1f00712 100644 --- a/amcat4/preprocessing/task.py +++ b/amcat4/preprocessing/task.py @@ -1,9 +1,12 @@ import functools from multiprocessing import Value +from re import I from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel import jsonpath_ng +from amcat4.models import FieldType + """ https://huggingface.co/docs/api-inference/detailed_parameters @@ -15,7 +18,7 @@ class PreprocessingRequest(BaseModel): template: Optional[dict] = None -class PreprocessingOutput(BaseModel): +class PreprocessingSetting(BaseModel): name: str type: str = "string" path: Optional[str] = None @@ -25,11 +28,16 @@ def parsed(self) -> jsonpath_ng.JSONPath: return jsonpath_ng.parse(self.path) -class PreprocessingParameter(PreprocessingOutput): +class PreprocessingOutput(PreprocessingSetting): + recommended_type: FieldType + + +class PreprocessingParameter(PreprocessingSetting): use_field: Literal["yes", "no"] = "no" default: Optional[bool | str | int | float] = None placeholder: Optional[str] = None 
header: Optional[bool] = None + secret: Optional[bool] = False class PreprocessingEndpoint(BaseModel): @@ -84,9 +92,10 @@ def get_output(self, name) -> PreprocessingOutput: use_field="no", header=True, path="Authorization:Bearer", + secret=True, ), ], - outputs=[PreprocessingOutput(name="label", path="$.labels[0]")], + outputs=[PreprocessingOutput(name="label", recommended_type="keyword", path="$.labels[0]")], request=PreprocessingRequest(body="json", template={"inputs": "", "parameters": {"candidate_labels": ""}}), ), PreprocessingTask( @@ -104,9 +113,10 @@ def get_output(self, name) -> PreprocessingOutput: use_field="no", header=True, path="Authorization:Bearer", + secret=True, ), ], - outputs=[PreprocessingOutput(name="label", path="$[0].label")], + outputs=[PreprocessingOutput(name="label", recommended_type="keyword", path="$[0].label")], request=PreprocessingRequest(body="binary"), ), ] diff --git a/setup.py b/setup.py index d19636e..52b5edb 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,6 @@ "minio", "jsonpath_ng", ], - extras_require={"dev": ["pytest", "mypy", "flake8", "responses", "pre-commit", "types-requests"]}, + extras_require={"dev": ["pytest", "pytest-httpx", "mypy", "flake8", "responses", "pre-commit", "types-requests"]}, entry_points={"console_scripts": ["amcat4 = amcat4.__main__:main"]}, ) From ffed41844d37a959b4dd24572338de35009c4f75 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Fri, 19 Apr 2024 14:55:43 +1000 Subject: [PATCH 72/80] Better error handling --- amcat4/preprocessing/processor.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py index 8a8e06d..bbf29d3 100644 --- a/amcat4/preprocessing/processor.py +++ b/amcat4/preprocessing/processor.py @@ -70,11 +70,14 @@ def start_processors(): async def run_processor_loop(index, instruction: PreprocessingInstruction): logger.info(f"Starting preprocessing loop for {index}.{instruction.field}") while True: - logger.info(f"Preprocessing loop woke up for {index}.{instruction.field}") - get_manager().preprocessor_status[index, instruction.field] = "Active" - await process_documents(index, instruction) - get_manager().preprocessor_status[index, instruction.field] = "Sleeping" - logger.info(f"Preprocessing loop sleeping for {index}.{instruction.field}") + try: + logger.info(f"Preprocessing loop woke up for {index}.{instruction.field}") + get_manager().preprocessor_status[index, instruction.field] = "Active" + await process_documents(index, instruction) + get_manager().preprocessor_status[index, instruction.field] = "Sleeping" + logger.info(f"Preprocessing loop sleeping for {index}.{instruction.field}") + except Exception: + logger.exception(f"Error on preprocessing {index}.{instruction.field}") await asyncio.sleep(10) @@ -104,7 +107,7 @@ def get_todo(index: str, instruction: PreprocessingInstruction, size=100): def get_counts(index: str, field: str): agg = dict(status=dict(terms=dict(field=f"{field}.status"))) - res = es().search(index="test", size=0, aggs=agg) + res = es().search(index=index, size=0, aggs=agg) result = dict(total=res["hits"]["total"]["value"]) for bucket in res["aggregations"]["status"]["buckets"]: result[bucket["key"]] = bucket["doc_count"] @@ -113,7 +116,11 @@ def get_counts(index: str, field: str): async def process_doc(index: str, instruction: PreprocessingInstruction, doc: dict): # TODO catch errors and add to status field, rather than raising - req = instruction.build_request(index, 
doc) + try: + req = instruction.build_request(index, doc) + except Exception as e: + logging.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}") + update_document(index, doc["_id"], {instruction.field: dict(status="error", error=str(e))}) try: response = await AsyncClient().send(req) response.raise_for_status() From a99325c06eab0dad09a547d63a26d3c2d873e467 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Fri, 19 Apr 2024 15:48:21 +1000 Subject: [PATCH 73/80] Add TLS option for minio --- amcat4/config.py | 1 + amcat4/multimedia.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/amcat4/config.py b/amcat4/config.py index 5364df8..adb8285 100644 --- a/amcat4/config.py +++ b/amcat4/config.py @@ -111,6 +111,7 @@ class Settings(BaseSettings): ] = None minio_host: Annotated[str | None, Field()] = None + minio_tls: Annotated[bool, Field()] = False minio_access_key: Annotated[str | None, Field()] = None minio_secret_key: Annotated[str | None, Field()] = None diff --git a/amcat4/multimedia.py b/amcat4/multimedia.py index d6fda86..5266b5f 100644 --- a/amcat4/multimedia.py +++ b/amcat4/multimedia.py @@ -38,7 +38,12 @@ def _connect_minio() -> Optional[Minio]: return None if settings.minio_secret_key is None or settings.minio_access_key is None: raise ValueError("minio_access_key or minio_secret_key not specified") - return Minio(settings.minio_host, secure=False, access_key=settings.minio_access_key, secret_key=settings.minio_secret_key) + return Minio( + settings.minio_host, + secure=settings.minio_tls, + access_key=settings.minio_access_key, + secret_key=settings.minio_secret_key, + ) def bucket_name(index: str) -> str: From ab51771bd07ce3a67bfe850997399f14fbccd7db Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Fri, 19 Apr 2024 15:52:11 +1000 Subject: [PATCH 74/80] add https for presigned post --- amcat4/multimedia.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/amcat4/multimedia.py b/amcat4/multimedia.py index 5266b5f..58c016c 100644 --- a/amcat4/multimedia.py +++ b/amcat4/multimedia.py @@ -113,7 +113,7 @@ def presigned_post(index: str, key_prefix: str = "", days_valid=1): bucket = get_bucket(minio, index) policy = PostPolicy(bucket, expiration=datetime.datetime.now() + datetime.timedelta(days=days_valid)) policy.add_starts_with_condition("key", key_prefix) - url = f"http://{get_settings().minio_host}/{bucket}" + url = f"http{'s' if get_settings().minio_tls else ''}://{get_settings().minio_host}/{bucket}" return url, minio.presigned_post_policy(policy) From fd1739b1cbdc9361e44cb2534fd1170ca995a9b4 Mon Sep 17 00:00:00 2001 From: Kasper Welbers Date: Fri, 19 Apr 2024 10:11:01 +0200 Subject: [PATCH 75/80] return on exception --- amcat4/preprocessing/processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py index bbf29d3..d1d0601 100644 --- a/amcat4/preprocessing/processor.py +++ b/amcat4/preprocessing/processor.py @@ -121,6 +121,7 @@ async def process_doc(index: str, instruction: PreprocessingInstruction, doc: di except Exception as e: logging.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}") update_document(index, doc["_id"], {instruction.field: dict(status="error", error=str(e))}) + return try: response = await AsyncClient().send(req) response.raise_for_status() From e05d0faa19d47d13d7b09671c63ae55416a5d4e3 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Tue, 23 Apr 2024 13:54:31 +1000 
Subject: [PATCH 76/80] Update preprocessing: stop on done, restart on upload,
 allow restarting errors

---
 amcat4/api/preprocessing.py         |   3 +-
 amcat4/fields.py                    |   1 -
 amcat4/index.py                     |  44 ++++++++-
 amcat4/preprocessing/instruction.py |  33 -------
 amcat4/preprocessing/processor.py   | 145 ++++++++++++++++++----------
 tests/test_preprocessing.py         | 121 ++++++++++++++++++++---
 6 files changed, 245 insertions(+), 102 deletions(-)
 delete mode 100644 amcat4/preprocessing/instruction.py

diff --git a/amcat4/api/preprocessing.py b/amcat4/api/preprocessing.py
index d35fe6e..a035eec 100644
--- a/amcat4/api/preprocessing.py
+++ b/amcat4/api/preprocessing.py
@@ -3,7 +3,8 @@

 from amcat4 import index
 from amcat4.api.auth import authenticated_user, check_role
-from amcat4.preprocessing.instruction import PreprocessingInstruction, get_instruction, get_instructions, add_instruction
+from amcat4.preprocessing.models import PreprocessingInstruction
+from amcat4.index import get_instruction, get_instructions, add_instruction
 from amcat4.preprocessing.processor import get_counts, get_manager
 from amcat4.preprocessing.task import get_tasks

diff --git a/amcat4/fields.py b/amcat4/fields.py
index 1a5717a..8ed86cc 100644
--- a/amcat4/fields.py
+++ b/amcat4/fields.py
@@ -215,7 +215,6 @@ def create_fields(index: str, fields: Mapping[str, FieldType | CreateField]):

     if len(mapping) > 0:
         # if there are new identifiers, check whether this is allowed first
-        print(json.dumps(mapping, indent=2))
         es().indices.put_mapping(index=index, properties=mapping)
         es().update(
             index=get_settings().system_index,
diff --git a/amcat4/index.py b/amcat4/index.py
index 72332e2..49b3ea7 100644
--- a/amcat4/index.py
+++ b/amcat4/index.py
@@ -56,6 +56,8 @@
     get_fields,
 )
 from amcat4.models import CreateField, Field, FieldType
+from amcat4.preprocessing.models import PreprocessingInstruction
+from amcat4.preprocessing import processor


 class Role(IntEnum):
@@ -228,7 +230,7 @@ def deregister_index(index: str, ignore_missing=False) -> None:

     # Stop preprocessing loops on this index
     from amcat4.preprocessing.processor import get_manager

-    get_manager().stop_preprocessors(index)
+    get_manager().remove_index_preprocessors(index)


 def _roles_from_elastic(roles: list[dict]) -> dict[str, Role]:
@@ -501,6 +503,9 @@ def es_actions(index, documents, op_type):
                 e.args = e.args + (f"First error: {reason}",)
             raise

+    # Start preprocessors for this index (if any)
+    processor.get_manager().start_index_preprocessors(index)
+
    return dict(successes=successes, failures=failures)


@@ -543,7 +548,7 @@ def update_by_query(index: str | list[str], script: str, query: dict, params: di
     return dict(updated=result["updated"], total=result["total"])


-TAG_SCRIPTS = dict(
+UPDATE_SCRIPTS = dict(
    add="""
    if (ctx._source[params.field] == null) {
      ctx._source[params.field] = [params.tag]
@@ -570,6 +575,39 @@ def update_by_query(index: str | list[str], script: str, query: dict, params: di

 def update_tag_by_query(index: str | list[str], action: Literal["add", "remove"], query: dict, field: str, tag: str):
     create_or_verify_tag_field(index, field)
-    script = TAG_SCRIPTS[action]
+    script = UPDATE_SCRIPTS[action]
     params = dict(field=field, tag=tag)
     return update_by_query(index, script, query, params)
+
+
+def get_instructions(index: str) -> Iterable[PreprocessingInstruction]:
+    res = es().get(index=get_settings().system_index, id=index, source="preprocessing")
+    for i in res["_source"].get("preprocessing", []):
+        for a in i.get("arguments", []):
+            if a.get("secret"):
+                a["value"] = "********"
+        yield 
PreprocessingInstruction.model_validate(i) + + +def get_instruction(index: str, field: str) -> Optional[PreprocessingInstruction]: + for i in get_instructions(index): + if i.field == field: + return i + + +def add_instruction(index: str, instruction: PreprocessingInstruction): + if instruction.field in get_fields(index): + raise ValueError(f"Field {instruction.field} already exists in index {index}") + instructions = list(get_instructions(index)) + instructions.append(instruction) + create_fields(index, {instruction.field: "preprocess"}) + body = [i.model_dump() for i in instructions] + es().update(index=get_settings().system_index, id=index, doc=dict(preprocessing=body)) + processor.get_manager().add_preprocessor(index, instruction) + + +def reassign_preprocessing_errors(index: str, field: str): + """Reset status for any documents with error status, and restart preprocessor""" + query = dict(query=dict(term={f"{field}.status": dict(value="error")})) + update_by_query(index, "ctx._source[params.field] = null", query, dict(field=field)) + processor.get_manager().start_preprocessor(index, field) diff --git a/amcat4/preprocessing/instruction.py b/amcat4/preprocessing/instruction.py deleted file mode 100644 index 3571894..0000000 --- a/amcat4/preprocessing/instruction.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Iterable, Optional -from amcat4.config import get_settings -from amcat4.elastic import es -from amcat4.fields import create_fields, get_fields -from amcat4.preprocessing.models import PreprocessingInstruction -from amcat4.preprocessing.processor import get_manager - - -def get_instructions(index: str) -> Iterable[PreprocessingInstruction]: - res = es().get(index=get_settings().system_index, id=index, source="preprocessing") - for i in res["_source"].get("preprocessing", []): - for a in i.get("arguments", []): - if a.get("secret"): - a["value"] = "********" - yield PreprocessingInstruction.model_validate(i) - - -def get_instruction(index: str, field: str) -> Optional[PreprocessingInstruction]: - for i in get_instructions(index): - if i.field == field: - return i - - -def add_instruction(index: str, instruction: PreprocessingInstruction): - print(instruction) - if instruction.field in get_fields(index): - raise ValueError(f"Field {instruction.field} already exists in index {index}") - instructions = list(get_instructions(index)) - instructions.append(instruction) - create_fields(index, {instruction.field: "preprocess"}) - body = [i.model_dump() for i in instructions] - es().update(index=get_settings().system_index, id=index, doc=dict(preprocessing=body)) - get_manager().add_preprocessor(index, instruction) diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py index d1d0601..a1b5f07 100644 --- a/amcat4/preprocessing/processor.py +++ b/amcat4/preprocessing/processor.py @@ -1,50 +1,82 @@ import asyncio from functools import cache import logging -from typing import Dict, Tuple +from typing import Dict, Literal, Tuple from elasticsearch import NotFoundError from httpx import AsyncClient, HTTPStatusError from amcat4.elastic import es -from amcat4.index import list_known_indices, refresh_index, update_document +import amcat4.index from amcat4.preprocessing.models import PreprocessingInstruction logger = logging.getLogger("amcat4.preprocessing") +PreprocessorStatus = Literal["Active", "Paused", "Unknown", "Error", "Stopped"] + + +class RateLimit(Exception): + pass + + +PAUSE_ON_RATE_LIMIT_SECONDS = 10 + class PreprocessorManager: SINGLETON = None def 
__init__(self):
-        self.preprocessors: Dict[Tuple[str, str], asyncio.Task] = {}
-        self.preprocessor_status: Dict[Tuple[str, str], str] = {}
+        self.preprocessors: Dict[Tuple[str, str], PreprocessingInstruction] = {}
+        self.running_tasks: Dict[Tuple[str, str], asyncio.Task] = {}
+        self.preprocessor_status: Dict[Tuple[str, str], PreprocessorStatus] = {}
+
+    def set_status(self, index: str, field: str, status: PreprocessorStatus):
+        self.preprocessor_status[index, field] = status

     def add_preprocessor(self, index: str, instruction: PreprocessingInstruction):
-        self.preprocessors[index, instruction.field] = asyncio.create_task(run_processor_loop(index, instruction))
+        """Start a new preprocessor task and add it to the manager, returning the Task object"""
+        self.preprocessors[index, instruction.field] = instruction
+        self.start_preprocessor(index, instruction.field)
+
+    def start_preprocessor(self, index: str, field: str):
+        if existing_task := self.running_tasks.get((index, field)):
+            if not existing_task.done():
+                return existing_task
+        instruction = self.preprocessors[index, field]
+        self.running_tasks[index, instruction.field] = asyncio.create_task(run_processor_loop(index, instruction))
+
+    def start_index_preprocessors(self, index: str):
+        for ix, field in self.preprocessors:
+            if ix == index:
+                self.start_preprocessor(index, field)

     def stop_preprocessor(self, index: str, field: str):
-        self.preprocessors[index, field].cancel()
-
-    def stop_preprocessors(self, index: str):
-        tasks = list(self.preprocessors.items())
-        for (ix, field), task in tasks:
-            if index == ix:
+        """Stop a preprocessor task"""
+        try:
+            if task := self.running_tasks.get((index, field)):
                 task.cancel()
-                del self.preprocessors[ix, field]
-                del self.preprocessor_status[ix, field]
+        except:
+            logging.exception(f"Error on cancelling preprocessor {index}:{field}")
+
+    def remove_preprocessor(self, index: str, field: str):
+        """Stop this preprocessor and remove it from the manager"""
+        self.stop_preprocessor(index, field)
+        del self.preprocessor_status[index, field]
+        del self.preprocessors[index, field]
+
+    def remove_index_preprocessors(self, index: str):
+        """Stop all preprocessors on this index and remove them from the manager"""
+        for ix, field in list(self.preprocessors.keys()):
+            if index == ix:
+                self.remove_preprocessor(ix, field)

-    def stop(self):
-        for task in self.preprocessors.values():
-            task.cancel()
+    def shutdown(self):
+        for task in self.running_tasks.values():
+            try:
+                task.cancel()
+            except:
+                logging.exception(f"Error on cancelling preprocessor")

-    def get_status(self, index: str, field: str):
-        task = self.preprocessors.get((index, field))
-        if not task:
-            return "Unknown"
-        if task.cancelled():
-            return "Cancelled"
-        if task.done():
-            return "Stopped"
-        return self.preprocessor_status.get((index, field), "Unknown status")
+    def get_status(self, index: str, field: str) -> PreprocessorStatus:
+        return self.preprocessor_status.get((index, field), "Unknown")


 @cache
@@ -53,13 +85,11 @@ def get_manager():


 def start_processors():
-    import amcat4.preprocessing.instruction
-
     logger.info("Starting preprocessing loops (if needed)")
     manager = get_manager()
-    for index in list_known_indices():
+    for index in amcat4.index.list_known_indices():
         try:
-            instructions = list(amcat4.preprocessing.instruction.get_instructions(index.id))
+            instructions = list(amcat4.index.get_instructions(index.id))
         except NotFoundError:
             logging.warning(f"Index {index.id} does not exist!")
             continue
         for instruction in instructions:
             manager.add_preprocessor(index.id, instruction)
@@ -68,33 +98,44 @@ 
async def run_processor_loop(index, instruction: PreprocessingInstruction): + """ + Main preprocessor loop. + Calls process_documents to process a batch of documents, until 'done' + """ + # TODO: add logic for pausing processing on hitting rate limit logger.info(f"Starting preprocessing loop for {index}.{instruction.field}") - while True: + get_manager().set_status(index, instruction.field, "Active") + done = False + while not done: try: - logger.info(f"Preprocessing loop woke up for {index}.{instruction.field}") - get_manager().preprocessor_status[index, instruction.field] = "Active" - await process_documents(index, instruction) - get_manager().preprocessor_status[index, instruction.field] = "Sleeping" - logger.info(f"Preprocessing loop sleeping for {index}.{instruction.field}") + done = await process_documents(index, instruction) + except RateLimit: + logger.info(f"Pausing preprocessing loop for {index}.{instruction.field}") + get_manager().set_status(index, instruction.field, "Paused") + await asyncio.sleep(PAUSE_ON_RATE_LIMIT_SECONDS) except Exception: logger.exception(f"Error on preprocessing {index}.{instruction.field}") - await asyncio.sleep(10) + get_manager().set_status(index, instruction.field, "Error") + return + get_manager().set_status(index, instruction.field, "Stopped") + logger.info(f"Stopping preprocessing loop for {index}.{instruction.field}") async def process_documents(index: str, instruction: PreprocessingInstruction, size=100): """ - Process all currently to-do documents in the index for this instruction. - Returns when it runs out of documents to do + Process a batch of currently to-do documents in the index for this instruction. + Return value indicates job completion: + It returns True when it runs out of documents to do, or False if there might be more documents. """ - # Q: it it better to repeat a simple "get n todo docs", or to iteratively scroll past all todo items? 
- while True: - docs = list(get_todo(index, instruction, size=size)) - logger.debug(f"Preprocessing for {index}.{instruction.field}: retrieved {len(docs)} docs to process") - for doc in docs: - await process_doc(index, instruction, doc) - if len(docs) < size: - return - refresh_index(index) + # Refresh index before getting new documents to make sure status updates are reflected + amcat4.index.refresh_index(index) + docs = list(get_todo(index, instruction, size=size)) + if not docs: + return True + logger.debug(f"Preprocessing for {index}.{instruction.field}: retrieved {len(docs)} docs to process") + for doc in docs: + await process_doc(index, instruction, doc) + return False def get_todo(index: str, instruction: PreprocessingInstruction, size=100): @@ -120,16 +161,18 @@ async def process_doc(index: str, instruction: PreprocessingInstruction, doc: di req = instruction.build_request(index, doc) except Exception as e: logging.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}") - update_document(index, doc["_id"], {instruction.field: dict(status="error", error=str(e))}) + amcat4.index.update_document(index, doc["_id"], {instruction.field: dict(status="error", error=str(e))}) return try: response = await AsyncClient().send(req) response.raise_for_status() except HTTPStatusError as e: - error = f"{e.response.status_code}: {e.response.text}" + if e.response.status_code == 503: + raise RateLimit(e) logging.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}") - update_document(index, doc["_id"], {instruction.field: dict(status="error", error=error)}) + body = dict(status="error", status_code=e.response.status_code, response=e.response.text) + amcat4.index.update_document(index, doc["_id"], {instruction.field: body}) return result = dict(instruction.parse_output(response.json())) result[instruction.field] = dict(status="done") - update_document(index, doc["_id"], result) + amcat4.index.update_document(index, doc["_id"], result) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 984909b..5cabe42 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,14 +1,22 @@ import asyncio import time +import httpx from pytest_httpx import HTTPXMock import json import pytest from amcat4.fields import create_fields -from amcat4.index import get_document, refresh_index -from amcat4.preprocessing.instruction import PreprocessingInstruction, add_instruction +from amcat4.index import get_document, reassign_preprocessing_errors, refresh_index, upload_documents, add_instruction +from amcat4.preprocessing.models import PreprocessingInstruction +from amcat4.preprocessing import processor -from amcat4.preprocessing.processor import PreprocessorManager, get_todo, process_doc, process_documents +from amcat4.preprocessing.processor import ( + get_counts, + get_manager, + get_todo, + process_doc, + process_documents, +) from tests.conftest import TEST_DOCUMENTS INSTRUCTION = dict( @@ -20,10 +28,10 @@ ) -def test_build_request(): +def test_build_request(index): i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION)) doc = dict(text="Sample text") - req = i.build_request(doc) + req = i.build_request(index, doc) assert req.url == INSTRUCTION["endpoint"] assert json.loads(req.content) == dict(inputs=doc["text"], parameters=dict(candidate_labels=["politics", "sports"])) @@ -37,39 +45,126 @@ def test_parse_result(): @pytest.mark.asyncio async def test_preprocess(index_docs, httpx_mock: HTTPXMock): + """Test 
logic of process_doc and get_todo calls"""
     i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]})
+
+    # Create a preprocess field. There should now be |docs| todos
     create_fields(index_docs, {i.field: "preprocess"})
     todos = list(get_todo(index_docs, i))
     assert all(set(todo.keys()) == {"_id", "text"} for todo in todos)
     assert {doc["_id"] for doc in todos} == {str(doc["_id"]) for doc in TEST_DOCUMENTS}

+    # Process a single document. Check that it's done, and that the todo list is now one shorter
     todo = sorted(todos, key=lambda todo: todo["_id"])[0]
-    httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]})
     await process_doc(index_docs, i, todo)
     doc = get_document(index_docs, todo["_id"])
     assert doc[i.field] == {"status": "done"}
     assert doc["class_label"] == "politics"
-    refresh_index(index_docs)
     todos = list(get_todo(index_docs, i))
     assert {doc["_id"] for doc in todos} == {str(doc["_id"]) for doc in TEST_DOCUMENTS} - {todo["_id"]}

-    # run all preprocessors in a loop
-    await process_documents(index_docs, i, size=2)
+    # run a single preprocessing loop, check that done is False and that the todo list shrank accordingly
+    done = await process_documents(index_docs, i, size=2)
+    assert done == False
+    refresh_index(index_docs)
+    todos = list(get_todo(index_docs, i))
+    assert len(todos) == len(TEST_DOCUMENTS) - (2 + 1)
+
+    # run preprocessing until it returns done = True
+    while not done:
+        done = await process_documents(index_docs, i, size=2)
+
+    # Todo should be empty, and there should be one call per document!
     refresh_index(index_docs)
     todos = list(get_todo(index_docs, i))
     assert len(todos) == 0
-    # There should be one call per document!
     assert len(httpx_mock.get_requests()) == len(TEST_DOCUMENTS)


 @pytest.mark.asyncio
 async def test_preprocess_loop(index_docs, httpx_mock: HTTPXMock):
+    """Test that adding an instruction automatically processes all docs in an index"""
     i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
     httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]})
     add_instruction(index_docs, i)
-    while len(httpx_mock.get_requests()) < len(TEST_DOCUMENTS):
-        await asyncio.sleep(0.1)
-    await asyncio.sleep(0.5)
+    await get_manager().running_tasks[index_docs, i.field]
     assert len(httpx_mock.get_requests()) == len(TEST_DOCUMENTS)
     assert all(get_document(index_docs, doc["_id"])["class_label"] == "politics" for doc in TEST_DOCUMENTS)
+
+
+@pytest.mark.asyncio
+async def test_preprocess_logic(index, httpx_mock: HTTPXMock):
+    """Test that the main processing loop works correctly"""
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+
+    async def mock_slow_response(_request) -> httpx.Response:
+        await asyncio.sleep(0.5)
+        return httpx.Response(json={"labels": ["politics"], "scores": [1]}, status_code=200)
+
+    httpx_mock.add_callback(mock_slow_response, url=i.endpoint)
+
+    # Add the instruction. Since there are no documents, it should return instantly-ish
+    add_instruction(index, i)
+    await asyncio.sleep(0.1)
+    assert get_manager().get_status(index, i.field) == "Stopped"
+
+    # Add a document. The task should be re-activated and take half a second to complete
+    upload_documents(index, [{"text": "text"}], fields={"text": "text"})
+    await asyncio.sleep(0.1)
+    assert get_manager().get_status(index, i.field) == "Active"
+    await asyncio.sleep(0.5)
+    assert get_manager().get_status(index, i.field) == "Stopped"
+
+
+@pytest.mark.asyncio
+async def test_preprocess_ratelimit(index_docs, httpx_mock: HTTPXMock):
+    """Test that processing is paused on hitting a rate limit, and restarts automatically"""
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    httpx_mock.add_response(url=i.endpoint, status_code=503)
+
+    # Set a low pause time for the test
+    processor.PAUSE_ON_RATE_LIMIT_SECONDS = 0.5
+
+    # Start the async preprocessing loop. On receiving a 503 it should pause and retry
+    add_instruction(index_docs, i)
+    await asyncio.sleep(0.1)
+    assert get_manager().get_status(index_docs, i.field) == "Paused"
+
+    # Now mock a success response and wait for .5 seconds
+    httpx_mock.reset(assert_all_responses_were_requested=True)
+    httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]})
+    await asyncio.sleep(0.5)
+    assert get_manager().get_status(index_docs, i.field) == "Stopped"
+
+
+@pytest.mark.asyncio
+async def test_preprocess_error(index_docs, httpx_mock: HTTPXMock):
+    """Test that errors are reported correctly"""
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+
+    def some_errors(request):
+        input = json.loads(request.content)["inputs"]
+        if "text" in input:  # should be true for 2 documents
+            return httpx.Response(json={"error": "I'm a teapot!"}, status_code=418)
+        else:
+            return httpx.Response(json={"labels": ["politics"], "scores": [1]}, status_code=200)
+
+    httpx_mock.add_callback(some_errors, url=i.endpoint)
+    add_instruction(index_docs, i)
+    await get_manager().running_tasks[index_docs, i.field]
+    for doc in TEST_DOCUMENTS:
+        result = get_document(index_docs, doc["_id"])
+        assert result[i.field]["status"] == ("error" if "text" in doc["text"] else "done")
+    assert get_counts(index_docs, i.field) == dict(total=4, done=2, error=2)
+
+    httpx_mock.reset(True)
+    httpx_mock.add_response(url=i.endpoint, json={"labels": ["sports"], "scores": [1]})
+    reassign_preprocessing_errors(index_docs, i.field)
+    await get_manager().running_tasks[index_docs, i.field]
+    for doc in TEST_DOCUMENTS:
+        result = get_document(index_docs, doc["_id"])
+        assert result[i.field]["status"] == "done"
+        assert result["class_label"] == ("sports" if "text" in doc["text"] else "politics")
+    assert len(httpx_mock.get_requests()) == 2

From 7d55290bdece00642ad54d8b96908158746a7138 Mon Sep 17 00:00:00 2001
From: Wouter van Atteveldt
Date: Tue, 23 Apr 2024 15:58:35 +1000
Subject: [PATCH 77/80] Switch multimedia tests to mock minio

---
 setup.py                     | 13 ++++++++++++-
 tests/conftest.py            | 19 +++++++++----------
 tests/test_api_multimedia.py |  5 +++--
 tests/test_multimedia.py     |  2 ++
 4 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/setup.py b/setup.py
index 52b5edb..d16748d 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,17 @@
         "minio",
         "jsonpath_ng",
     ],
-    extras_require={"dev": ["pytest", "pytest-httpx", "mypy", "flake8", "responses", "pre-commit", "types-requests"]},
+    extras_require={
+        "dev": [
+            "pytest",
+            "pytest-httpx",
+            "mypy",
+            "flake8",
+            "responses",
+            "pre-commit",
+            "types-requests",
+            "pytest-minio-mock @ git+ssh://git@github.com/vanatteveldt/pytest-minio-mock.git",
+        ]
+    },
     entry_points={"console_scripts": ["amcat4 = amcat4.__main__:main"]},
= amcat4.__main__:main"]}, ) diff --git a/tests/conftest.py b/tests/conftest.py index 1574702..e7aa8a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,9 +33,7 @@ def mock_middlecat(): get_settings().middlecat_url = "http://localhost:5000" get_settings().host = "http://localhost:3000" - minio = get_settings().minio_host - passthru = (f"http://{minio}",) if minio else () - with responses.RequestsMock(passthru_prefixes=passthru, assert_all_requests_are_fired=False) as resp: + with responses.RequestsMock(assert_all_requests_are_fired=False) as resp: resp.get("http://localhost:5000/api/configuration", json={"public_key": PUBLIC_KEY}) yield None @@ -217,10 +215,11 @@ def app(): @pytest.fixture() -def minio(): - minio = multimedia.connect_minio() - if not minio: - pytest.skip("No minio connected, skipping multimedia tests") - for index in ["amcat4_unittest_index"]: - multimedia.delete_bucket(minio, index) - return minio +def minio(minio_mock): + from minio.deleteobjects import DeleteObject + + minio = multimedia.get_minio() + for bucket in minio.list_buckets(): + for x in minio.list_objects(bucket.name, recursive=True): + minio.remove_object(x.bucket_name, x.object_name) + minio.remove_bucket(bucket.name) diff --git a/tests/test_api_multimedia.py b/tests/test_api_multimedia.py index 7660a6a..ce28495 100644 --- a/tests/test_api_multimedia.py +++ b/tests/test_api_multimedia.py @@ -1,4 +1,5 @@ from fastapi.testclient import TestClient +import pytest import requests from amcat4 import multimedia from amcat4.index import set_role, Role @@ -24,6 +25,7 @@ def test_authorisation(minio, client, index, user, reader): def test_post_get_list(minio, client, index, user): + pytest.skip("mock minio does not allow presigned post, skipping for now") set_role(index, user, Role.WRITER) assert _get_names(client, index, user) == set() post = client.get(f"index/{index}/multimedia/presigned_post", headers=build_headers(user)).json() @@ -53,8 +55,7 @@ def test_list_options(minio, client, index, reader): f"index/{index}/multimedia/list", params=dict(prefix="myfolder/", presigned_get=True), headers=build_headers(reader) ) res.raise_for_status() - urls = {o["key"]: o["presigned_get"] for o in res.json()} - assert requests.get(urls["myfolder/a1"]).content == b"a1" + assert all("presigned_get" in o for o in res.json()) def test_list_pagination(minio, client, index, reader): diff --git a/tests/test_multimedia.py b/tests/test_multimedia.py index 042b1a0..561d097 100644 --- a/tests/test_multimedia.py +++ b/tests/test_multimedia.py @@ -1,5 +1,6 @@ from io import BytesIO import os +import pytest import requests from amcat4 import multimedia @@ -11,6 +12,7 @@ def test_upload_get_multimedia(minio, index): def test_presigned_form(minio, index): + pytest.skip("mock minio does not allow presigned post, skipping for now") assert list(multimedia.list_multimedia_objects(index)) == [] bytes = os.urandom(32) key = "image.png" From fe7cdf71ebaa6481925141a7c0e21577273d8cf3 Mon Sep 17 00:00:00 2001 From: Wouter van Atteveldt Date: Wed, 24 Apr 2024 11:53:17 +1000 Subject: [PATCH 78/80] Added API endpoints for restarting/reassiging tasks --- amcat4/api/preprocessing.py | 39 +++++++++++++- amcat4/index.py | 11 ++++ amcat4/preprocessing/processor.py | 40 ++++++++------- setup.py | 1 + tests/conftest.py | 14 ++++- tests/test_api_preprocessing.py | 85 ++++++++++++++++++++++++++++++- tests/test_preprocessing.py | 6 +-- tests/tools.py | 14 +++-- 8 files changed, 179 insertions(+), 31 deletions(-) diff --git a/amcat4/api/preprocessing.py 
index a035eec..0c3f2f1 100644
--- a/amcat4/api/preprocessing.py
+++ b/amcat4/api/preprocessing.py
@@ -1,10 +1,19 @@
+import asyncio
 import logging
-from fastapi import APIRouter, Depends, HTTPException, Response, status
+from typing import Annotated, Literal
+from fastapi import APIRouter, Body, Depends, HTTPException, Response, status

 from amcat4 import index
 from amcat4.api.auth import authenticated_user, check_role
 from amcat4.preprocessing.models import PreprocessingInstruction
-from amcat4.index import get_instruction, get_instructions, add_instruction
+from amcat4.index import (
+    get_instruction,
+    get_instructions,
+    add_instruction,
+    reassign_preprocessing_errors,
+    start_preprocessor,
+    stop_preprocessor,
+)
 from amcat4.preprocessing.processor import get_counts, get_manager
 from amcat4.preprocessing.task import get_tasks

@@ -42,3 +51,29 @@ async def get_instruction_details(ix: str, field: str, user: str = Depends(authe
     state = get_manager().get_status(ix, field)
     counts = get_counts(ix, field)
     return dict(instruction=i, status=state, counts=counts)
+
+
+@app_preprocessing.get("/index/{ix}/preprocessing/{field}/status")
+async def get_status(ix: str, field: str, user: str = Depends(authenticated_user)):
+    return dict(status=get_manager().get_status(ix, field))
+
+
+@app_preprocessing.post(
+    "/index/{ix}/preprocessing/{field}/status", status_code=status.HTTP_204_NO_CONTENT, response_class=Response
+)
+async def set_status(
+    ix: str,
+    field: str,
+    user: str = Depends(authenticated_user),
+    action: Literal["Start", "Stop", "Reassign"] = Body(description="Action to perform on this preprocessing task", embed=True),
+):
+    check_role(user, index.Role.WRITER, ix)
+    current_status = get_manager().get_status(ix, field)
+    if action == "Start" and current_status in {"Unknown", "Error", "Stopped", "Done"}:
+        start_preprocessor(ix, field)
+    elif action == "Stop" and current_status in {"Active"}:
+        stop_preprocessor(ix, field)
+    elif action == "Reassign":
+        reassign_preprocessing_errors(ix, field)
+    else:
+        raise HTTPException(422, f"Cannot {action} (status: {current_status}; field {ix}.{field})")
diff --git a/amcat4/index.py b/amcat4/index.py
index 49b3ea7..1aa4690 100644
--- a/amcat4/index.py
+++ b/amcat4/index.py
@@ -580,6 +580,9 @@ def update_tag_by_query(index: str | list[str], action: Literal["add", "remove"]
     return update_by_query(index, script, query, params)


+### WvA Should probably move these to multimedia/actions or something
+
+
 def get_instructions(index: str) -> Iterable[PreprocessingInstruction]:
     res = es().get(index=get_settings().system_index, id=index, source="preprocessing")
     for i in res["_source"].get("preprocessing", []):
@@ -611,3 +614,11 @@ def reassign_preprocessing_errors(index: str, field: str):
     query = dict(query=dict(term={f"{field}.status": dict(value="error")}))
     update_by_query(index, "ctx._source[params.field] = null", query, dict(field=field))
     processor.get_manager().start_preprocessor(index, field)
+
+
+def stop_preprocessor(index: str, field: str):
+    processor.get_manager().stop_preprocessor(index, field)
+
+
+def start_preprocessor(index: str, field: str):
+    processor.get_manager().start_preprocessor(index, field)
diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py
index a1b5f07..26f266a 100644
--- a/amcat4/preprocessing/processor.py
+++ b/amcat4/preprocessing/processor.py
@@ -10,7 +10,7 @@

 logger = logging.getLogger("amcat4.preprocessing")

-PreprocessorStatus = Literal["Active", "Paused", "Unknown", "Error", "Stopped"]
"Unknown", "Error", "Stopped"] +PreprocessorStatus = Literal["Active", "Paused", "Unknown", "Error", "Stopped", "Done"] class RateLimit(Exception): @@ -41,7 +41,8 @@ def start_preprocessor(self, index: str, field: str): if not existing_task.done: return existing_task instruction = self.preprocessors[index, field] - self.running_tasks[index, instruction.field] = asyncio.create_task(run_processor_loop(index, instruction)) + task = asyncio.create_task(run_processor_loop(index, instruction)) + self.running_tasks[index, instruction.field] = task def start_index_preprocessors(self, index: str): for ix, field in self.preprocessors: @@ -54,7 +55,7 @@ def stop_preprocessor(self, index: str, field: str): if task := self.running_tasks.get((index, field)): task.cancel() except: - logging.exception(f"Error on cancelling preprocessor {index}:{field}") + logger.exception(f"Error on cancelling preprocessor {index}:{field}") def remove_preprocessor(self, index: str, field: str): """Stop this preprocessor remove them from the manager""" @@ -68,15 +69,13 @@ def remove_index_preprocessors(self, index: str): if index == ix: self.remove_preprocessor(ix, field) - def shutdown(self): - for task in self.running_tasks.values(): - try: - task.cancel() - except: - logging.exception(f"Error on cancelling preprocessor") - def get_status(self, index: str, field: str) -> PreprocessorStatus: - return self.preprocessor_status.get((index, field), "Unknown") + status = self.preprocessor_status.get((index, field), "Unknown") + task = self.running_tasks.get((index, field)) + if (not task) or task.done() and status == "Active": + logger.warning(f"Preprocessor {index}.{field} is {status}, but has no running task: {task}") + return "Unknown" + return status @cache @@ -91,7 +90,7 @@ def start_processors(): try: instructions = list(amcat4.index.get_instructions(index.id)) except NotFoundError: - logging.warning(f"Index {index.id} does not exist!") + logger.warning(f"Index {index.id} does not exist!") continue for instruction in instructions: manager.add_preprocessor(index.id, instruction) @@ -102,23 +101,26 @@ async def run_processor_loop(index, instruction: PreprocessingInstruction): Main preprocessor loop. 
     Calls process_documents to process a batch of documents, until 'done'
     """
-    # TODO: add logic for pausing processing on hitting rate limit
-    logger.info(f"Starting preprocessing loop for {index}.{instruction.field}")
+    logger.info(f"Preprocessing START for {index}.{instruction.field}")
     get_manager().set_status(index, instruction.field, "Active")
     done = False
     while not done:
         try:
             done = await process_documents(index, instruction)
+        except asyncio.CancelledError:
+            logger.info(f"Preprocessing CANCEL for {index}.{instruction.field}")
+            get_manager().set_status(index, instruction.field, "Stopped")
+            raise
         except RateLimit:
-            logger.info(f"Pausing preprocessing loop for {index}.{instruction.field}")
+            logger.info(f"Preprocessing RATELIMIT for {index}.{instruction.field}")
             get_manager().set_status(index, instruction.field, "Paused")
             await asyncio.sleep(PAUSE_ON_RATE_LIMIT_SECONDS)
         except Exception:
-            logger.exception(f"Error on preprocessing {index}.{instruction.field}")
+            logger.exception(f"Preprocessing ERROR for {index}.{instruction.field}")
             get_manager().set_status(index, instruction.field, "Error")
             return
-    get_manager().set_status(index, instruction.field, "Stopped")
-    logger.info(f"Stopping preprocessing loop for {index}.{instruction.field}")
+    get_manager().set_status(index, instruction.field, "Done")
+    logger.info(f"Preprocessing DONE for {index}.{instruction.field}")
@@ -160,7 +162,7 @@ async def process_doc(index: str, instruction: PreprocessingInstruction, doc: di
     try:
         req = instruction.build_request(index, doc)
     except Exception as e:
-        logging.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}")
+        logger.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}")
         amcat4.index.update_document(index, doc["_id"], {instruction.field: dict(status="error", error=str(e))})
         return
diff --git a/setup.py b/setup.py
index d16748d..26c13fe 100644
--- a/setup.py
+++ b/setup.py
@@ -43,6 +43,7 @@
             "responses",
             "pre-commit",
             "types-requests",
+            "pytest-asyncio",
             "pytest-minio-mock @ git+ssh://git@github.com/vanatteveldt/pytest-minio-mock.git",
         ]
     },
diff --git a/tests/conftest.py b/tests/conftest.py
index e7aa8a9..ea97252 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,10 @@
-from typing import Any
+import logging
+from typing import Any, AsyncGenerator, AsyncIterable
 import pytest
+import pytest_asyncio
 import responses
 from fastapi.testclient import TestClient
+from httpx import ASGITransport, AsyncClient

 from amcat4 import api, multimedia  # noqa: E402
 from amcat4.config import get_settings, AuthOptions
@@ -221,5 +224,12 @@ def minio(minio_mock):
     minio = multimedia.get_minio()
     for bucket in minio.list_buckets():
         for x in minio.list_objects(bucket.name, recursive=True):
-            minio.remove_object(x.bucket_name, x.object_name)
+            minio.remove_object(x.bucket_name, x.object_name or "")
         minio.remove_bucket(bucket.name)
+
+
+@pytest_asyncio.fixture
+async def aclient(app) -> AsyncIterable[AsyncClient]:
+    host = get_settings().host
+    async with AsyncClient(transport=ASGITransport(app=app), base_url=host) as c:
+        yield c
diff --git a/tests/test_api_preprocessing.py b/tests/test_api_preprocessing.py
index a05e72d..85e5d28 100644
--- a/tests/test_api_preprocessing.py
+++ b/tests/test_api_preprocessing.py
@@ -1,11 +1,17 @@
 import asyncio
 import json
+import logging
+
+import httpx
 import pytest
-
-from amcat4.index import Role, get_document, refresh_index, set_role
+
+from amcat4.index import Role, add_instruction, get_document, refresh_index, set_role
 from amcat4.preprocessing.models import PreprocessingInstruction
+from amcat4.preprocessing.processor import get_manager
 from tests.conftest import TEST_DOCUMENTS
 from tests.test_preprocessing import INSTRUCTION
-from tests.tools import build_headers, check
+from tests.tools import aget_json, build_headers, check, get_json
+
+logger = logging.getLogger("amcat4.tests")


 def test_get_tasks(client):
@@ -49,3 +55,78 @@ async def test_post_get_instructions(client, user, index_docs, httpx_mock):

     # Cannot re-add the same field
     check(client.post(f"/index/{index_docs}/preprocessing", json=i.model_dump(), headers=build_headers(user=user)), 400)
+
+
+@pytest.mark.asyncio
+async def test_pause_restart(aclient: httpx.AsyncClient, admin, index_docs, httpx_mock, caplog):
+    async def slow_response(request):
+        json.loads(request.content)["inputs"]
+        await asyncio.sleep(0.1)
+        return httpx.Response(json={"labels": ["politics"], "scores": [1]}, status_code=200)
+
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    httpx_mock.add_callback(slow_response, url=i.endpoint)
+    status_url = f"/index/{index_docs}/preprocessing/{i.field}/status"
+
+    # Start the preprocessor, wait .15 seconds
+    add_instruction(index_docs, i)
+    await asyncio.sleep(0.15)
+
+    assert (await aget_json(aclient, status_url, user=admin))["status"] == "Active"
+
+    # Tell the processor to stop
+    check(await aclient.post(status_url, json=dict(action="Stop"), headers=build_headers(user=admin)), 204)
+    await asyncio.sleep(0)
+    assert (await aget_json(aclient, status_url, user=admin))["status"] == "Stopped"
+
+    # Some, but not all docs should be done yet
+    assert len(httpx_mock.get_requests()) < len(TEST_DOCUMENTS)
+    assert len(httpx_mock.get_requests()) > 0
+
+    # Restart processor
+    check(await aclient.post(status_url, json=dict(action="Start"), headers=build_headers(user=admin)), 204)
+    await asyncio.sleep(0)
+    assert (await aget_json(aclient, status_url, user=admin))["status"] == "Active"
+
+    await get_manager().running_tasks[index_docs, i.field]
+    assert (await aget_json(aclient, status_url, user=admin))["status"] == "Done"
+
+    # There should be at most one extra request (the cancelled one)
+    assert len(httpx_mock.get_requests()) <= len(TEST_DOCUMENTS) + 1
+
+
+@pytest.mark.asyncio
+async def test_reassign_error(aclient: httpx.AsyncClient, admin, index_docs, httpx_mock):
+    async def mistakes_were_made(request):
+        await asyncio.sleep(0.1)
+        input = json.loads(request.content)["inputs"]
+        if "text" in input:  # should be true for 2 documents
+            return httpx.Response(json={"kettle": "black"}, status_code=418)
+        else:
+            return httpx.Response(json={"labels": ["first pass"]}, status_code=200)
+
+    i = PreprocessingInstruction.model_validate_json(json.dumps(INSTRUCTION))
+    httpx_mock.add_callback(mistakes_were_made, url=i.endpoint)
+
+    # Start the preprocessor and wait for it to finish
+    add_instruction(index_docs, i)
+    await get_manager().running_tasks[index_docs, i.field]
+    field_url = f"/index/{index_docs}/preprocessing/{i.field}"
+    status_url = f"{field_url}/status"
+
+    res = await aget_json(aclient, field_url, user=admin)
+    assert res["status"] == "Done"
+    assert res["counts"] == {"total": 4, "done": 2, "error": 2}
+
+    httpx_mock.reset(True)
+    httpx_mock.add_response(url=i.endpoint, json={"labels": ["secondpass"]})
+
+    check(await aclient.post(status_url, json=dict(action="Reassign"), headers=build_headers(user=admin)), 
204) + await get_manager().running_tasks[index_docs, i.field] + + res = await aget_json(aclient, field_url, user=admin) + assert res["status"] == "Done" + assert res["counts"] == {"total": 4, "done": 4} + + # Check that only error'd documents are reassigned + assert len(httpx_mock.get_requests()) == 2 diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 5cabe42..34bc584 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -108,14 +108,14 @@ async def mock_slow_response(_request) -> httpx.Response: # Add the instruction. Since there are no documents, it should return instantly-ish add_instruction(index, i) await asyncio.sleep(0.1) - assert get_manager().get_status(index, i.field) == "Stopped" + assert get_manager().get_status(index, i.field) == "Done" # Add a document. The task should be re-activated and take half a second to complete upload_documents(index, [{"text": "text"}], fields={"text": "text"}) await asyncio.sleep(0.1) assert get_manager().get_status(index, i.field) == "Active" await asyncio.sleep(0.5) - assert get_manager().get_status(index, i.field) == "Stopped" + assert get_manager().get_status(index, i.field) == "Done" @pytest.mark.asyncio @@ -136,7 +136,7 @@ async def test_preprocess_ratelimit(index_docs, httpx_mock: HTTPXMock): httpx_mock.reset(assert_all_responses_were_requested=True) httpx_mock.add_response(url=i.endpoint, json={"labels": ["politics", "sports"], "scores": [0.9, 0.1]}) await asyncio.sleep(0.5) - assert get_manager().get_status(index_docs, i.field) == "Stopped" + assert get_manager().get_status(index_docs, i.field) == "Done" @pytest.mark.asyncio diff --git a/tests/tools.py b/tests/tools.py index e2651af..c522c91 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -5,6 +5,7 @@ from authlib.jose import jwt from fastapi.testclient import TestClient +from httpx import AsyncClient from amcat4.config import AuthOptions, get_settings from amcat4.index import refresh_index @@ -30,13 +31,20 @@ def build_headers(user=None, headers=None): return headers -def get_json(client: TestClient, url: str, expected=200, headers=None, user=None, **kargs): +def get_json(client: TestClient, url: str, expected=200, headers=None, user=None, **kargs) -> dict: """Get the given URL. If expected is 2xx, return the result as parsed json""" response = client.get(url, headers=build_headers(user, headers), **kargs) content = response.json() if response.content else None assert response.status_code == expected, f"GET {url} returned {response.status_code}, expected {expected}, {content}" - if expected // 100 == 2: - return content + return {} if content is None else content + + +async def aget_json(client: AsyncClient, url: str, expected=200, headers=None, user=None, **kargs) -> dict: + """Get the given URL. 
If expected is 2xx, return the result as parsed json"""
+    response = await client.get(url, headers=build_headers(user, headers), **kargs)
+    content = response.json() if response.content else None
+    assert response.status_code == expected, f"GET {url} returned {response.status_code}, expected {expected}, {content}"
+    return {} if content is None else content


 def post_json(client: TestClient, url, expected=201, headers=None, user=None, **kargs):

From 7b1dfa7d175747b64a61c2a0937b67670ed392f4 Mon Sep 17 00:00:00 2001
From: Wouter van Atteveldt
Date: Fri, 26 Apr 2024 19:37:18 +1000
Subject: [PATCH 79/80] Set to active after backoff sleep

---
 amcat4/preprocessing/processor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/amcat4/preprocessing/processor.py b/amcat4/preprocessing/processor.py
index 26f266a..bfcb653 100644
--- a/amcat4/preprocessing/processor.py
+++ b/amcat4/preprocessing/processor.py
@@ -115,6 +115,7 @@ async def run_processor_loop(index, instruction: PreprocessingInstruction):
             logger.info(f"Preprocessing RATELIMIT for {index}.{instruction.field}")
             get_manager().set_status(index, instruction.field, "Paused")
             await asyncio.sleep(PAUSE_ON_RATE_LIMIT_SECONDS)
+            get_manager().set_status(index, instruction.field, "Active")
         except Exception:
             logger.exception(f"Preprocessing ERROR for {index}.{instruction.field}")
             get_manager().set_status(index, instruction.field, "Error")
@@ -171,7 +172,7 @@ async def process_doc(index: str, instruction: PreprocessingInstruction, doc: di
     except HTTPStatusError as e:
         if e.response.status_code == 503:
             raise RateLimit(e)
-        logging.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}")
+        logging.exception(f"Error on preprocessing {index}.{instruction.field} doc {doc['_id']}: {e.response.text}")
         body = dict(status="error", status_code=e.response.status_code, response=e.response.text)
         amcat4.index.update_document(index, doc["_id"], {instruction.field: body})
         return

From dbfcb7c3a942abf9fbd461012a4e1e4092e463dd Mon Sep 17 00:00:00 2001
From: Wouter van Atteveldt
Date: Fri, 26 Apr 2024 19:46:01 +1000
Subject: [PATCH 80/80] Add .venv to ignore

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 7440fed..a3e54ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@

 # .env environment variables
 .env
+.venv

 # C extensions
 *.so
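
Read together, patches 76-79 define the client-facing lifecycle of a preprocessing task: it runs to "Done", pauses itself on rate limits, and can be stopped, restarted, or reassigned over HTTP. The sketch below shows how an external client could drive the new status endpoint from patch 78; it is illustrative only, and the server URL, index name, field name, and token are placeholders rather than anything defined in these patches.

import httpx

SERVER = "http://localhost:5000"  # placeholder: wherever amcat4 is served
TOKEN = "..."  # placeholder: bearer token for a user with WRITER role on the index
headers = {"Authorization": f"Bearer {TOKEN}"}
url = f"{SERVER}/index/my_index/preprocessing/class_label/status"

# GET returns e.g. {"status": "Active"}; possible values include
# "Active", "Paused", "Error", "Stopped", "Done" and "Unknown"
status = httpx.get(url, headers=headers).json()["status"]

# POST {"action": ...} with "Start", "Stop" or "Reassign"; "Reassign" also
# clears the error status on failed documents so they are processed again
if status in {"Error", "Stopped"}:
    r = httpx.post(url, headers=headers, json={"action": "Start"})
    r.raise_for_status()  # 204 on success, 422 if the transition is not allowed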
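
The minio changes in patches 73 and 74 are likewise driven purely by configuration. A minimal sketch of enabling TLS, mirroring how the test suite mutates the cached settings object; the host and credentials are placeholder values that would normally be read from the environment:

from amcat4.config import get_settings

settings = get_settings()
settings.minio_host = "minio.example.com:9000"  # placeholder host:port
settings.minio_tls = True  # Minio client now connects over TLS, and presigned_post() builds https:// URLs
settings.minio_access_key = "example-access-key"  # placeholder credentials
settings.minio_secret_key = "example-secret-key"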