Skip to content

Commit

Permalink
Add update_by_query endpoint (fixes #54)
Browse files Browse the repository at this point in the history
  • Loading branch information
vanatteveldt committed Jun 5, 2024
1 parent bd37050 commit d47823b
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 16 deletions.
40 changes: 36 additions & 4 deletions amcat4/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

from amcat4 import query, aggregate
from amcat4.aggregate import Axis, Aggregation
from amcat4.api.auth import authenticated_user, check_fields_access
from amcat4.api.auth import authenticated_user, check_fields_access, check_role
from amcat4.config import AuthOptions, get_settings
from amcat4.fields import create_fields
from amcat4.index import Role, get_role, get_fields
from amcat4.index import Role, get_role, get_fields, update_documents_by_query
from amcat4.models import FieldSpec, FilterSpec, FilterValue, SortSpec
from amcat4.query import update_tag_query
from amcat4.query import update_query, update_tag_query

app_query = APIRouter(prefix="/index", tags=["query"])

Expand Down Expand Up @@ -303,8 +303,10 @@ def query_aggregate_post(
Returns a JSON object {data: [{axis1, ..., n, aggregate1, ...}, ...], meta: {axes: [...], aggregations: [...]}
"""
# TODO check user rights on index

indices = index.split(",")
for index in indices:
check_role(user, Role.READER, index)
_axes = [Axis(**x.model_dump()) for x in axes] if axes else []
_aggregations = [Aggregation(**x.model_dump()) for x in aggregations] if aggregations else []

Expand Down Expand Up @@ -367,3 +369,33 @@ def query_update_tags(
indices, action, field, tag, _standardize_queries(queries), _standardize_filters(filters), ids
)
return update_result


@app_query.post("/{index}/update_by_query")
def update_by_query(
index: str,
field: Annotated[str, Body(description="Field to update")],
value: Annotated[str | None, Body(description="New value for the field, or null/None to delete field")],
queries: Annotated[
str | list[str] | dict[str, str] | None,
Body(
description="Query/Queries to run. Value should be a single query string, a list of query strings, "
"or a dict of {'label': 'query'}",
),
] = None,
filters: Annotated[
dict[str, FilterValue | list[FilterValue] | FilterSpec] | None,
Body(
description="Field filters, should be a dict of field names to filter specifications,"
"which can be either a value, a list of values, or a FilterSpec dict",
),
] = None,
user: str = Depends(authenticated_user),
):
"""
Update documents by query.
Select documents using queries and/or filters, and specify a field and new value.
"""
check_role(user, Role.WRITER, index)
return update_query(index, field, value, _standardize_queries(queries), _standardize_filters(filters))
26 changes: 15 additions & 11 deletions amcat4/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,12 +542,6 @@ def delete_document(index: str, doc_id: str):
es().delete(index=index, id=doc_id)


def update_by_query(index: str | list[str], script: str, query: dict, params: dict | None = None):
script_dict = dict(source=script, lang="painless", params=params or {})
result = es().update_by_query(index=index, script=script_dict, **query, refresh=True)
return dict(updated=result["updated"], total=result["total"])


UDATE_SCRIPTS = dict(
add="""
if (ctx._source[params.field] == null) {
Expand Down Expand Up @@ -575,9 +569,19 @@ def update_by_query(index: str | list[str], script: str, query: dict, params: di

def update_tag_by_query(index: str | list[str], action: Literal["add", "remove"], query: dict, field: str, tag: str):
create_or_verify_tag_field(index, field)
script = UDATE_SCRIPTS[action]
params = dict(field=field, tag=tag)
return update_by_query(index, script, query, params)
script = dict(source=UDATE_SCRIPTS[action], lang="painless", params=dict(field=field, tag=tag))
result = es().update_by_query(index=index, script=script, **query, refresh=True)
return dict(updated=result["updated"], total=result["total"])


def update_documents_by_query(index: str | list[str], query: dict, field: str, value: Any):
if value is None:
script = dict(source="ctx._source.remove(params.field)", lang="painless", params=dict(field=field))
else:
script = dict(
source="ctx._source[params.field] = params.value", lang="painless", params=dict(field=field, value=value)
)
return es().update_by_query(index=index, query=query, script=script, refresh=True)


### WvA Should probably move these to multimedia/actions or something
Expand Down Expand Up @@ -611,8 +615,8 @@ def add_instruction(index: str, instruction: PreprocessingInstruction):

def reassign_preprocessing_errors(index: str, field: str):
"""Reset status for any documents with error status, and restart preprocessor"""
query = dict(query=dict(term={f"{field}.status": dict(value="error")}))
update_by_query(index, "ctx._source[params.field] = null", query, dict(field=field))
query = dict(term={f"{field}.status": dict(value="error")})
update_documents_by_query(index, query, field, None)
processor.get_manager().start_preprocessor(index, field)


Expand Down
14 changes: 13 additions & 1 deletion amcat4/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .date_mappings import mappings
from .elastic import es
from amcat4.index import update_tag_by_query
from amcat4.index import update_documents_by_query, update_tag_by_query


def build_body(
Expand Down Expand Up @@ -273,3 +273,15 @@ def update_tag_query(

update_result = update_tag_by_query(index, action, body, field, tag)
return update_result


def update_query(
index: str | list[str],
field: str,
value: Any,
queries: dict[str, str] | None = None,
filters: dict[str, FilterSpec] | None = None,
ids: list[str] | None = None,
):
query = build_body(queries, filters, ids=ids)
return update_documents_by_query(index=index, query=query["query"], field=field, value=value)
14 changes: 14 additions & 0 deletions tests/test_api_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,17 @@ def tags():
)
assert res["updated"] == 2
assert tags() == {"1": ["y"], "2": ["x", "y"]}


def test_api_update_by_query(client, index_docs, admin):
def cats():
res = query_documents(index_docs, fields=[FieldSpec(name="cat"), FieldSpec(name="subcat")])
return {doc["_id"]: doc.get("subcat") for doc in (res.data if res else [])}

res = client.post(
f"/index/{index_docs}/update_by_query",
json=dict(field="subcat", value="z", filters=dict(cat="a")),
headers=build_headers(user=admin),
)
res.raise_for_status()
assert cats() == {"0": "z", "1": "z", "2": "z", "3": "y"}
14 changes: 14 additions & 0 deletions tests/test_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from amcat4.index import (
refresh_index,
update_documents_by_query,
upload_documents,
get_document,
update_document,
Expand Down Expand Up @@ -75,6 +76,19 @@ def test_update(index_docs):
assert get_document(index_docs, "0", _source=["annotations"])["annotations"] == {"x": 3}


def test_update_by_query(index_docs):
def cats():
res = query_documents(index_docs, fields=[FieldSpec(name="cat"), FieldSpec(name="subcat")])
return {doc["_id"]: doc.get("subcat") for doc in (res.data if res else [])}

assert cats() == {"0": "x", "1": "x", "2": "y", "3": "y"}
update_documents_by_query(index_docs, query=dict(term={"cat": dict(value="a")}), field="subcat", value="z")
assert cats() == {"0": "z", "1": "z", "2": "z", "3": "y"}
update_documents_by_query(index_docs, query=dict(term={"cat": dict(value="b")}), field="subcat", value=None)
assert cats() == {"0": "z", "1": "z", "2": "z", "3": None}
assert "subcat" not in get_document(index_docs, "3").keys()


def test_add_tag(index_docs):
def q(*ids):
return dict(query=dict(ids={"values": ids}))
Expand Down

0 comments on commit d47823b

Please sign in to comment.