diff --git a/crossref/restful.py b/crossref/restful.py index 7e134f1..7874f7e 100644 --- a/crossref/restful.py +++ b/crossref/restful.py @@ -1,10 +1,12 @@ # coding: utf-8 import requests +import requests_cache from time import sleep from crossref import validators, VERSION + LIMIT = 100 MAXOFFSET = 10000 FACETS_MAX_LIMIT = 1000 @@ -132,6 +134,7 @@ class Endpoint: def __init__( self, + backend=None, request_url=None, request_params=None, context=None, @@ -139,7 +142,10 @@ def __init__( throttle=True, crossref_plus_token=None, timeout=30, + ): + if backend: + requests_cache.install_cache(cache_name='crossref_cache', backend=backend) self.do_http_request = HTTPRequest(throttle=throttle).do_http_request self.etiquette = etiquette or Etiquette() self.custom_header = {"user-agent": str(self.etiquette)} @@ -573,6 +579,10 @@ class Works(Endpoint): "update-type": None, } + def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True, crossref_plus_token=None, timeout=30, backend=None): + super().__init__(request_url, request_params, context, etiquette, throttle, crossref_plus_token, timeout, backend) + self.backend = backend + def order(self, order="asc"): """ This method retrieve an iterable object that implements the method @@ -631,6 +641,7 @@ def order(self, order="asc"): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def select(self, *args): @@ -729,6 +740,7 @@ def select(self, *args): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def sort(self, sort="score"): @@ -789,6 +801,7 @@ def sort(self, sort="score"): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def filter(self, **kwargs): @@ -841,6 +854,7 @@ def filter(self, **kwargs): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def facet(self, facet_name, facet_count=100): @@ -926,7 +940,8 @@ def query(self, *args, **kwargs): request_params=request_params, context=context, etiquette=self.etiquette, - timeout=self.timeout) + timeout=self.timeout, + backend=self.backend,) def sample(self, sample_size=20): """ @@ -971,6 +986,7 @@ def sample(self, sample_size=20): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def doi(self, doi, only_message=True): @@ -1112,6 +1128,10 @@ class Funders(Endpoint): "location": None, } + def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True, crossref_plus_token=None, timeout=30, backend=None,): + super().__init__(request_url, request_params, context, etiquette, throttle, crossref_plus_token, timeout, backend,) + self.backend = backend + def query(self, *args): """ This method retrieve an iterable object that implements the method @@ -1144,6 +1164,7 @@ def query(self, *args): request_params=request_params, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def filter(self, **kwargs): @@ -1197,6 +1218,7 @@ def filter(self, **kwargs): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def funder(self, funder_id, only_message=True): @@ -1298,6 +1320,10 @@ class Members(Endpoint): "current-doi-count": validators.is_integer, } + def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True, crossref_plus_token=None, timeout=30, backend=None,): + super().__init__(request_url, request_params, context, etiquette, throttle, crossref_plus_token, timeout, backend) + self.backend = backend + def query(self, *args): """ This method retrieve an iterable object that implements the method @@ -1349,6 +1375,7 @@ def query(self, *args): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def filter(self, **kwargs): @@ -1401,6 +1428,7 @@ def filter(self, **kwargs): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def member(self, member_id, only_message=True): @@ -1692,6 +1720,10 @@ class Journals(Endpoint): ENDPOINT = "journals" + def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True, crossref_plus_token=None, timeout=30, backend=None): + super().__init__(request_url, request_params, context, etiquette, throttle, crossref_plus_token, timeout, backend) + self.backend=backend + def query(self, *args): """ This method retrieve an iterable object that implements the method @@ -1726,6 +1758,7 @@ def query(self, *args): context=context, etiquette=self.etiquette, timeout=self.timeout, + backend=self.backend, ) def journal(self, issn, only_message=True): diff --git a/setup.py b/setup.py index 2e623c3..dd52f21 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,8 @@ from crossref import VERSION install_requires = [ - 'requests>=2.11.1' + 'requests>=2.11.1', + 'requests-cache>=1.0.0', ] tests_require = [] diff --git a/tests/test_restful.py b/tests/test_restful.py index 95619db..0a497fe 100644 --- a/tests/test_restful.py +++ b/tests/test_restful.py @@ -1,6 +1,7 @@ # coding: utf-8 import unittest +import pathlib as pl from crossref import restful from crossref import VERSION @@ -67,6 +68,14 @@ def test_work_with_sample_and_filters(self): self.assertEqual(result, 'https://api.crossref.org/works?filter=type%3Ajournal-article&sample=5') + def test_work_with_backend(self): + _ = restful.Works(backend='filesystem') + path = 'crossref_cache' + + if not pl.Path(path).resolve().is_dir(): + raise AssertionError("File does not exist: %s" % str(path)) + + def test_members_filters(self): result = restful.Members(etiquette=self.etiquette).filter(has_public_references="true").url