Skip to content

Commit

Permalink
fix total for progress bar by counting deprecated docs (#957)
Browse files Browse the repository at this point in the history
* docstring fix

* raise RestError on 400

* fix total by also counting deprecated docs

* ruff
  • Loading branch information
tschaume authored Dec 16, 2024
1 parent f9936de commit e6797ce
Showing 1 changed file with 40 additions and 27 deletions.
67 changes: 40 additions & 27 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import inspect
import itertools
import json
import os
Expand Down Expand Up @@ -419,7 +420,7 @@ def _query_resource(
"""Query the endpoint for a Resource containing a list of documents
and meta information about pagination and total document count.
For the end-user, methods .query() and .count() are intended to be
For the end-user, methods .search() and .count() are intended to be
easier to use.
Arguments:
Expand Down Expand Up @@ -994,8 +995,9 @@ def _submit_request_and_process(
)

if response.status_code in [400]:
warnings.warn(
f"The server does not support the request made to {response.url}. This may be due to an outdated mp-api package, or a problem with the query."
raise MPRestError(
f"The server does not support the request made to {response.url}. "
"This may be due to an outdated mp-api package, or a problem with the query."
)

if response.status_code == 200:
Expand Down Expand Up @@ -1266,34 +1268,45 @@ def count(self, criteria: dict | None = None) -> int | str:
"""Return a count of total documents.
Args:
criteria (dict | None): As in .query(). Defaults to None
criteria (dict | None): As in .search(). Defaults to None
Returns:
(int | str): Count of total results, or string indicating error
"""
try:
criteria = criteria or {}
user_preferences = (
self.monty_decode,
self.use_document_model,
self.mute_progress_bars,
)
self.monty_decode, self.use_document_model, self.mute_progress_bars = (
False,
False,
True,
) # do not waste cycles decoding
results = self._query_resource(
criteria=criteria, num_chunks=1, chunk_size=1
)
(
self.monty_decode,
self.use_document_model,
self.mute_progress_bars,
) = user_preferences
return results["meta"]["total_doc"]
except Exception: # pragma: no cover
return "Problem getting count"
criteria = criteria or {}
user_preferences = (
self.monty_decode,
self.use_document_model,
self.mute_progress_bars,
)
self.monty_decode, self.use_document_model, self.mute_progress_bars = (
False,
False,
True,
) # do not waste cycles decoding
results = self._query_resource(criteria=criteria, num_chunks=1, chunk_size=1)
cnt = results["meta"]["total_doc"]

no_query = not {field for field in criteria if field[0] != "_"}
if no_query and hasattr(self, "search"):
allowed_params = inspect.getfullargspec(self.search).args
if "deprecated" in allowed_params:
criteria["deprecated"] = True
results = self._query_resource(
criteria=criteria, num_chunks=1, chunk_size=1
)
cnt += results["meta"]["total_doc"]
warnings.warn(
"Omitting a query also includes deprecated documents in the results. "
"Make sure to post-filter them out."
)

(
self.monty_decode,
self.use_document_model,
self.mute_progress_bars,
) = user_preferences
return cnt

@property
def available_fields(self) -> list[str]:
Expand Down

0 comments on commit e6797ce

Please sign in to comment.