Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cache functionality #50

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion crossref/restful.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# coding: utf-8

import requests
import requests_cache
from time import sleep

from crossref import validators, VERSION


LIMIT = 100
MAXOFFSET = 10000
FACETS_MAX_LIMIT = 1000
Expand Down Expand Up @@ -132,14 +134,18 @@ class Endpoint:

def __init__(
self,
backend=None,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backend parameter should be the last one. Some users don't use named parameters and your commit will break their scripts.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought, since all the parameters had default values, if I gave backend also a default value it wouldn't break anything. But it is something that I can change it, so OK.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave a second thought, and I understood what you are saying. You're right. I will change it as soon as possible.

request_url=None,
request_params=None,
context=None,
etiquette=None,
throttle=True,
crossref_plus_token=None,
timeout=30,

):
if backend:
requests_cache.install_cache(cache_name='crossref_cache', backend=backend)
self.do_http_request = HTTPRequest(throttle=throttle).do_http_request
self.etiquette = etiquette or Etiquette()
self.custom_header = {"user-agent": str(self.etiquette)}
Expand Down Expand Up @@ -573,6 +579,10 @@ class Works(Endpoint):
"update-type": None,
}

def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True, crossref_plus_token=None, timeout=30, backend=None):
super().__init__(request_url, request_params, context, etiquette, throttle, crossref_plus_token, timeout, backend)
self.backend = backend

def order(self, order="asc"):
"""
This method retrieve an iterable object that implements the method
Expand Down Expand Up @@ -631,6 +641,7 @@ def order(self, order="asc"):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def select(self, *args):
Expand Down Expand Up @@ -729,6 +740,7 @@ def select(self, *args):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def sort(self, sort="score"):
Expand Down Expand Up @@ -789,6 +801,7 @@ def sort(self, sort="score"):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def filter(self, **kwargs):
Expand Down Expand Up @@ -841,6 +854,7 @@ def filter(self, **kwargs):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def facet(self, facet_name, facet_count=100):
Expand Down Expand Up @@ -926,7 +940,8 @@ def query(self, *args, **kwargs):
request_params=request_params,
context=context,
etiquette=self.etiquette,
timeout=self.timeout)
timeout=self.timeout,
backend=self.backend,)

def sample(self, sample_size=20):
"""
Expand Down Expand Up @@ -971,6 +986,7 @@ def sample(self, sample_size=20):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def doi(self, doi, only_message=True):
Expand Down Expand Up @@ -1112,6 +1128,10 @@ class Funders(Endpoint):
"location": None,
}

def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True, crossref_plus_token=None, timeout=30, backend=None,):
super().__init__(request_url, request_params, context, etiquette, throttle, crossref_plus_token, timeout, backend,)
self.backend = backend

def query(self, *args):
"""
This method retrieve an iterable object that implements the method
Expand Down Expand Up @@ -1144,6 +1164,7 @@ def query(self, *args):
request_params=request_params,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def filter(self, **kwargs):
Expand Down Expand Up @@ -1197,6 +1218,7 @@ def filter(self, **kwargs):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def funder(self, funder_id, only_message=True):
Expand Down Expand Up @@ -1298,6 +1320,10 @@ class Members(Endpoint):
"current-doi-count": validators.is_integer,
}

def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True, crossref_plus_token=None, timeout=30, backend=None,):
super().__init__(request_url, request_params, context, etiquette, throttle, crossref_plus_token, timeout, backend)
self.backend = backend

def query(self, *args):
"""
This method retrieve an iterable object that implements the method
Expand Down Expand Up @@ -1349,6 +1375,7 @@ def query(self, *args):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def filter(self, **kwargs):
Expand Down Expand Up @@ -1401,6 +1428,7 @@ def filter(self, **kwargs):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def member(self, member_id, only_message=True):
Expand Down Expand Up @@ -1692,6 +1720,10 @@ class Journals(Endpoint):

ENDPOINT = "journals"

def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True, crossref_plus_token=None, timeout=30, backend=None):
super().__init__(request_url, request_params, context, etiquette, throttle, crossref_plus_token, timeout, backend)
self.backend=backend

def query(self, *args):
"""
This method retrieve an iterable object that implements the method
Expand Down Expand Up @@ -1726,6 +1758,7 @@ def query(self, *args):
context=context,
etiquette=self.etiquette,
timeout=self.timeout,
backend=self.backend,
)

def journal(self, issn, only_message=True):
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from crossref import VERSION

install_requires = [
'requests>=2.11.1'
'requests>=2.11.1',
'requests-cache>=1.0.0',
]

tests_require = []
Expand Down
9 changes: 9 additions & 0 deletions tests/test_restful.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8

import unittest
import pathlib as pl

from crossref import restful
from crossref import VERSION
Expand Down Expand Up @@ -67,6 +68,14 @@ def test_work_with_sample_and_filters(self):

self.assertEqual(result, 'https://api.crossref.org/works?filter=type%3Ajournal-article&sample=5')

def test_work_with_backend(self):
_ = restful.Works(backend='filesystem')
path = 'crossref_cache'

if not pl.Path(path).resolve().is_dir():
raise AssertionError("File does not exist: %s" % str(path))


def test_members_filters(self):
result = restful.Members(etiquette=self.etiquette).filter(has_public_references="true").url

Expand Down