Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing requests session for customisation #170

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,14 @@ General Usage
First, create a new Zotero instance:


.. py:class:: Zotero(library_id, library_type[, api_key, preserve_json_order, locale])
.. py:class:: Zotero(library_id, library_type[, api_key, preserve_json_order, locale, session])
:param str library_id: a valid Zotero API user ID
:param str library_type: a valid Zotero API library type: **user** or **group**
:param str api_key: a valid Zotero API user key
:param bool preserve_json_order: Load JSON returns with OrderedDict to preserve their order
:param str locale: Set the `locale <https://www.zotero.org/support/dev/web_api/v3/types_and_fields#zotero_web_api_item_typefield_requests>`_, allowing retrieval of localised item types, field types, and creator types. Defaults to "en-US".
:param requests.Session session: a custom requests session, for example to use `requests-cache <https://pypi.org/project/requests-cache/>`_


Example:
Expand Down
46 changes: 24 additions & 22 deletions src/pyzotero/zotero.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def __init__(
api_key=None,
preserve_json_order=False,
locale="en-US",
session=None,
):
"""Store Zotero credentials"""
self.endpoint = "https://api.zotero.org"
Expand All @@ -303,6 +304,7 @@ def __init__(
self.api_key = api_key
self.preserve_json_order = preserve_json_order
self.locale = locale
self.s = session or requests.Session()
self.url_params = None
self.tag_data = False
self.request = None
Expand Down Expand Up @@ -416,7 +418,7 @@ def _retrieve_data(self, request=None, params=None):
self.self_link = request
# ensure that we wait if there's an active backoff
self._check_backoff()
self.request = requests.get(
self.request = self.s.get(
url=full_url, headers=self.default_headers(), params=params
)
self.request.encoding = "utf-8"
Expand Down Expand Up @@ -487,7 +489,7 @@ def _updated(self, url, payload, template=None):
headers.update(self.default_headers())
# perform the request, and check whether the response returns 304
self._check_backoff()
req = requests.get(query, headers=headers)
req = self.s.get(query, headers=headers)
try:
req.raise_for_status()
except requests.exceptions.HTTPError as exc:
Expand Down Expand Up @@ -615,7 +617,7 @@ def set_fulltext(self, itemkey, payload):
"""
headers = self.default_headers()
headers.update({"Content-Type": "application/json"})
return requests.put(
return self.s.put(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{k}/fulltext".format(
Expand All @@ -636,7 +638,7 @@ def new_fulltext(self, since):
)
headers = self.default_headers()
self._check_backoff()
resp = requests.get(build_url(self.endpoint, query_string), headers=headers)
resp = self.s.get(build_url(self.endpoint, query_string), headers=headers)
try:
resp.raise_for_status()
except requests.exceptions.HTTPError as exc:
Expand Down Expand Up @@ -1026,7 +1028,7 @@ def saved_search(self, name, conditions):
headers = {"Zotero-Write-Token": token()}
headers.update(self.default_headers())
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/searches".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1054,7 +1056,7 @@ def delete_saved_search(self, keys):
headers = {"Zotero-Write-Token": token()}
headers.update(self.default_headers())
self._check_backoff()
req = requests.delete(
req = self.s.delete(
url=build_url(
self.endpoint,
"/{t}/{u}/searches".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1230,7 +1232,7 @@ def create_items(self, payload, parentid=None, last_modified=None):
to_send = json.dumps([i for i in self._cleanup(*payload, allow=("key"))])
headers.update(self.default_headers())
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/items".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1260,7 +1262,7 @@ def create_items(self, payload, parentid=None, last_modified=None):
for value in resp["success"].values():
payload = json.dumps({"parentItem": parentid})
self._check_backoff()
presp = requests.patch(
presp = self.s.patch(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{v}".format(
Expand Down Expand Up @@ -1306,7 +1308,7 @@ def create_collections(self, payload, last_modified=None):
headers["If-Unmodified-Since-Version"] = str(last_modified)
headers.update(self.default_headers())
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/collections".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1338,7 +1340,7 @@ def update_collection(self, payload, last_modified=None):
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
headers.update({"Content-Type": "application/json"})
return requests.put(
return self.s.put(
url=build_url(
self.endpoint,
"/{t}/{u}/collections/{c}".format(
Expand Down Expand Up @@ -1397,7 +1399,7 @@ def update_item(self, payload, last_modified=None):
ident = payload["key"]
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.patch(
return self.s.patch(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{id}".format(
Expand All @@ -1420,7 +1422,7 @@ def update_items(self, payload):
# anything longer
for chunk in chunks(to_send, 50):
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/items/".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1450,7 +1452,7 @@ def update_collections(self, payload):
# anything longer
for chunk in chunks(to_send, 50):
self._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.endpoint,
"/{t}/{u}/collections/".format(
Expand Down Expand Up @@ -1483,7 +1485,7 @@ def addto_collection(self, collection, payload):
modified_collections = payload["data"]["collections"] + [collection]
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.patch(
return self.s.patch(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{i}".format(
Expand All @@ -1509,7 +1511,7 @@ def deletefrom_collection(self, collection, payload):
]
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.patch(
return self.s.patch(
url=build_url(
self.endpoint,
"/{t}/{u}/items/{i}".format(
Expand All @@ -1536,7 +1538,7 @@ def delete_tags(self, *payload):
"If-Unmodified-Since-Version": self.request.headers["last-modified-version"]
}
headers.update(self.default_headers())
return requests.delete(
return self.s.delete(
url=build_url(
self.endpoint,
"/{t}/{u}/tags".format(t=self.library_type, u=self.library_id),
Expand Down Expand Up @@ -1578,7 +1580,7 @@ def delete_item(self, payload, last_modified=None):
)
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.delete(url=url, params=params, headers=headers)
return self.s.delete(url=url, params=params, headers=headers)

@backoff_check
def delete_collection(self, payload, last_modified=None):
Expand Down Expand Up @@ -1613,7 +1615,7 @@ def delete_collection(self, payload, last_modified=None):
)
headers = {"If-Unmodified-Since-Version": str(modified)}
headers.update(self.default_headers())
return requests.delete(url=url, params=params, headers=headers)
return self.s.delete(url=url, params=params, headers=headers)


def error_handler(zot, req, exc=None):
Expand Down Expand Up @@ -1898,7 +1900,7 @@ def _create_prelim(self):
child["parentItem"] = self.parentid
to_send = json.dumps(self.payload)
self.zinstance._check_backoff()
req = requests.post(
req = self.s.post(
url=build_url(
self.zinstance.endpoint,
liblevel.format(
Expand Down Expand Up @@ -1946,7 +1948,7 @@ def _get_auth(self, attachment, reg_key, md5=None):
"params": 1,
}
self.zinstance._check_backoff()
auth_req = requests.post(
auth_req = self.s.post(
url=build_url(
self.zinstance.endpoint,
"/{t}/{u}/items/{i}/file".format(
Expand Down Expand Up @@ -1983,7 +1985,7 @@ def _upload_file(self, authdata, attachment, reg_key):
upload_pairs = tuple(upload_list)
try:
self.zinstance._check_backoff()
upload = requests.post(
upload = self.s.post(
url=authdata["url"],
files=upload_pairs,
headers={"User-Agent": "Pyzotero/%s" % pz.__version__},
Expand Down Expand Up @@ -2011,7 +2013,7 @@ def _register_upload(self, authdata, reg_key):
reg_headers.update(self.zinstance.default_headers())
reg_data = {"upload": authdata.get("uploadKey")}
self.zinstance._check_backoff()
upload_reg = requests.post(
upload_reg = self.s.post(
url=build_url(
self.zinstance.endpoint,
"/{t}/{u}/items/{i}/file".format(
Expand Down