Skip to content

Commit

Permalink
Merge pull request #447 from uhh-lt/fix-filter-search
Browse files Browse the repository at this point in the history
Fix filter search
  • Loading branch information
bigabig authored Oct 16, 2024
2 parents f1f3f0e + 559960d commit 80c8180
Show file tree
Hide file tree
Showing 23 changed files with 1,012 additions and 371 deletions.
167 changes: 143 additions & 24 deletions backend/src/api/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import List, Optional

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from api.dependencies import get_current_user, get_db_session
from api.dependencies import get_current_user
from app.core.authorization.authz_user import AuthzUser
from app.core.data.crud import Crud
from app.core.data.dto.search import (
Expand Down Expand Up @@ -48,14 +47,14 @@ def search_sdocs_info(
)
def search_sdocs(
*,
search_query: str,
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
highlight: bool,
page_number: Optional[int] = None,
page_size: Optional[int] = None,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
authz_user: AuthzUser = Depends(),
) -> PaginatedElasticSearchDocumentHits:
authz_user.assert_in_project(project_id)
Expand All @@ -72,26 +71,66 @@ def search_sdocs(


@router.post(
"/code_stats",
"/code_stats_by_search",
response_model=List[SpanEntityStat],
summary="Returns SpanEntityStats for the given SourceDocuments.",
summary="Returns SpanEntityStats for the given search parameters.",
)
def search_code_stats(
*,
db: Session = Depends(get_db_session),
authz_user: AuthzUser = Depends(),
# code stat params
code_id: int,
sdoc_ids: List[int],
sort_by_global: bool = False,
authz_user: AuthzUser = Depends(),
# search params
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
) -> List[SpanEntityStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
filter=filter,
sorts=sorts,
highlight=False,
)
sdoc_ids = [hit.document_id for hit in search_result.hits]
if len(sdoc_ids) == 0:
return []

# compute code stats
authz_user.assert_in_same_project_as(Crud.CODE, code_id)
authz_user.assert_in_same_project_as_many(Crud.SOURCE_DOCUMENT, sdoc_ids)
code_stats = SearchService().compute_code_statistics(
code_id=code_id, sdoc_ids=set(sdoc_ids)
)
if sort_by_global:
code_stats.sort(key=lambda x: x.global_count, reverse=True)
return code_stats


# TODO Flo for large corpora this gets very slow. Hence we have to set a limit and in future implement some lazy
# loading or scrolling in the frontend with skip and limit.
@router.post(
"/code_stats_by_sdocs",
response_model=List[SpanEntityStat],
summary="Returns SpanEntityStats for the given SourceDocuments.",
)
def filter_code_stats(
*,
authz_user: AuthzUser = Depends(),
# code stat params
code_id: int,
sort_by_global: bool = False,
# filter params
sdoc_ids: List[int],
) -> List[SpanEntityStat]:
if len(sdoc_ids) == 0:
return []

# compute code stats
authz_user.assert_in_same_project_as(Crud.CODE, code_id)
code_stats = SearchService().compute_code_statistics(
code_id=code_id, sdoc_ids=set(sdoc_ids)
)
Expand All @@ -101,24 +140,65 @@ def search_code_stats(


@router.post(
"/keyword_stats",
"/keyword_stats_by_search",
response_model=List[KeywordStat],
summary="Returns KeywordStats for the given SourceDocuments.",
summary="Returns KeywordStats for the given seach parameters.",
)
def search_keyword_stats(
*,
authz_user: AuthzUser = Depends(),
project_id: int,
sdoc_ids: List[int],
# keyword stat params
sort_by_global: bool = False,
top_k: int = 50,
authz_user: AuthzUser = Depends(),
# search params
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
) -> List[KeywordStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
filter=filter,
sorts=sorts,
highlight=False,
)
sdoc_ids = [hit.document_id for hit in search_result.hits]
if len(sdoc_ids) == 0:
return []

authz_user.assert_in_project(project_id)
authz_user.assert_in_same_project_as_many(Crud.SOURCE_DOCUMENT, sdoc_ids)
# compute keyword stats
keyword_stats = SearchService().compute_keyword_statistics(
proj_id=project_id, sdoc_ids=set(sdoc_ids), top_k=top_k
)
if sort_by_global:
keyword_stats.sort(key=lambda x: x.global_count, reverse=True)
return keyword_stats


@router.post(
"/keyword_stats_by_sdocs",
response_model=List[KeywordStat],
summary="Returns KeywordStats for the given SourceDocuments.",
)
def filter_keyword_stats(
*,
authz_user: AuthzUser = Depends(),
project_id: int,
# keyword stat params
sort_by_global: bool = False,
top_k: int = 50,
# filter params
sdoc_ids: List[int],
) -> List[KeywordStat]:
if len(sdoc_ids) == 0:
return []

# compute keyword stats
keyword_stats = SearchService().compute_keyword_statistics(
proj_id=project_id, sdoc_ids=set(sdoc_ids), top_k=top_k
)
Expand All @@ -128,21 +208,60 @@ def search_keyword_stats(


@router.post(
"/tag_stats",
"/tag_stats_by_search",
response_model=List[TagStat],
summary="Returns Stat for the given SourceDocuments.",
summary="Returns Stat for the given search parameters.",
)
def search_tag_stats(
*,
sdoc_ids: List[int],
sort_by_global: bool = False,
authz_user: AuthzUser = Depends(),
# keyword stat params
sort_by_global: bool = False,
# search params
project_id: int,
search_query: str,
expert_mode: bool,
filter: Filter[SearchColumns],
sorts: List[Sort[SearchColumns]],
) -> List[TagStat]:
# search for relevant sdoc_ids
authz_user.assert_in_project(project_id)
search_result = SearchService().search(
project_id=project_id,
search_query=search_query,
expert_mode=expert_mode,
filter=filter,
sorts=sorts,
highlight=False,
)
sdoc_ids = [hit.document_id for hit in search_result.hits]
if len(sdoc_ids) == 0:
return []

authz_user.assert_in_same_project_as_many(Crud.SOURCE_DOCUMENT, sdoc_ids)
# compute tag stats
tag_stats = SearchService().compute_tag_statistics(sdoc_ids=set(sdoc_ids))
if sort_by_global:
tag_stats.sort(key=lambda x: x.global_count, reverse=True)
return tag_stats


@router.post(
"/tag_stats_by_sdocs",
response_model=List[TagStat],
summary="Returns Stat for the given SourceDocuments.",
)
def filter_tag_stats(
*,
authz_user: AuthzUser = Depends(),
# keyword stat params
sort_by_global: bool = False,
# filter params
sdoc_ids: List[int],
) -> List[TagStat]:
if len(sdoc_ids) == 0:
return []

# compute tag stats
tag_stats = SearchService().compute_tag_statistics(sdoc_ids=set(sdoc_ids))
if sort_by_global:
tag_stats.sort(key=lambda x: x.global_count, reverse=True)
Expand Down
11 changes: 9 additions & 2 deletions backend/src/app/core/search/elasticsearch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def search_sdocs_by_content_query(
self,
*,
proj_id: int,
sdoc_ids: Set[int],
sdoc_ids: Optional[Set[int]],
query: str,
use_simple_query: bool = True,
highlight: bool = False,
Expand All @@ -425,9 +425,16 @@ def search_sdocs_by_content_query(

highlight_query = {"fields": {"content": {}}} if highlight else None

# the sdoc_ids parameter is for filtering the search results
# if it is None, all documents are searched
bool_must_query = [q]
if sdoc_ids is not None:
# the terms query has an allowed maximum of 65536 terms
bool_must_query.append({"terms": {"sdoc_id": list(sdoc_ids)[:65536]}})

return self.__search_sdocs(
proj_id=proj_id,
query={"bool": {"must": [{"terms": {"sdoc_id": list(sdoc_ids)}}, q]}},
query={"bool": {"must": bool_must_query}},
limit=limit,
skip=skip,
highlight=highlight_query,
Expand Down
Loading

0 comments on commit 80c8180

Please sign in to comment.