From 1e41e28a41a8a0cc473df367dd3c20a60b775053 Mon Sep 17 00:00:00 2001
From: Abram Booth
Date: Wed, 9 Aug 2023 16:55:59 -0400
Subject: [PATCH] api_tests.share overhaul

---
 api/share/utils.py                           |  95 +++++++--
 api_tests/providers/test_reindex_provider.py |  22 +-
 api_tests/share/__init__.py                  |   0
 api_tests/share/_utils.py                    |  27 +++
 api_tests/share/test_share_node.py           | 211 ++++---------------
 api_tests/share/test_share_preprint.py       | 195 ++++++++---------
 conftest.py                                  |   7 +-
 framework/celery_tasks/handlers.py           |   4 +-
 osf/metadata/tools.py                        |  56 +----
 osf/models/mixins.py                         |   2 +
 website/preprints/tasks.py                   |   2 +
 11 files changed, 257 insertions(+), 364 deletions(-)
 create mode 100644 api_tests/share/__init__.py
 create mode 100644 api_tests/share/_utils.py

diff --git a/api/share/utils.py b/api/share/utils.py
index cb1d9b2994f..534414c9603 100644
--- a/api/share/utils.py
+++ b/api/share/utils.py
@@ -7,17 +7,23 @@
 from celery.exceptions import Retry
 from django.apps import apps
+import requests
 
 from framework.celery_tasks import app as celery_app
 from framework.celery_tasks.handlers import enqueue_task
 from framework.sentry import log_exception
-from osf.metadata.tools import pls_send_trove_indexcard, pls_delete_trove_indexcard
+from osf.metadata.osf_gathering import osf_iri
+from osf.metadata.tools import pls_gather_metadata_file
 from website import settings
 
 logger = logging.getLogger(__name__)
 
 
+def shtrove_ingest_url():
+    return f'{settings.SHARE_URL}api/v3/ingest'
+
+
 def is_qa_resource(resource):
     """
     QA puts tags and special titles on their project to stop them from appearing in the search results. This function
@@ -46,29 +52,35 @@ def update_share(resource):
     enqueue_task(task__update_share.s(_osfguid_value))
 
 
+def do_update_share(osfguid: str):
+    logger.debug('%s.do_update_share("%s")', __name__, osfguid)
+    _guid_instance = apps.get_model('osf.Guid').load(osfguid)
+    if _guid_instance is None:
+        raise ValueError(f'unknown osfguid "{osfguid}"')
+    _resource = _guid_instance.referent
+    _response = (
+        pls_delete_trove_indexcard(_resource)
+        if _should_delete_indexcard(_resource)
+        else pls_send_trove_indexcard(_resource)
+    )
+    return _response
+
+
 @celery_app.task(bind=True, max_retries=4, acks_late=True)
-def task__update_share(self, guid: str, **kwargs):
+def task__update_share(self, guid: str):
     """
     This function updates SHARE for Preprints, Projects, and Registrations.
     :param self:
     :param guid:
     :return:
     """
-    _guid_instance = apps.get_model('osf.Guid').load(guid)
-    if _guid_instance is None:
-        raise ValueError(f'unknown osfguid "{guid}"')
-    resource = _guid_instance.referent
-    resp = (
-        pls_delete_trove_indexcard(resource)
-        if _should_delete_indexcard(resource)
-        else pls_send_trove_indexcard(resource)
-    )
+    resp = do_update_share(guid)
     try:
         resp.raise_for_status()
     except Exception as e:
         if self.request.retries == self.max_retries:
             log_exception()
-        elif resp.status_code >= 500 and settings.USE_CELERY:
+        elif resp.status_code >= 500:
             try:
                 self.retry(
                     exc=e,
@@ -82,13 +94,64 @@ def task__update_share(self, guid: str, **kwargs):
     return resp
 
 
+def pls_send_trove_indexcard(osf_item):
+    _iri = osf_iri(osf_item)
+    if not _iri:
+        raise ValueError(f'could not get iri for {osf_item}')
+    _metadata_record = pls_gather_metadata_file(osf_item, 'turtle')
+    return requests.post(
+        shtrove_ingest_url(),
+        params={
+            'focus_iri': _iri,
+            'record_identifier': _shtrove_record_identifier(osf_item),
+        },
+        headers={
+            'Content-Type': _metadata_record.mediatype,
+            **_shtrove_auth_headers(osf_item),
+        },
+        data=_metadata_record.serialized_metadata,
+    )
+
+
+def pls_delete_trove_indexcard(osf_item):
+    return requests.delete(
+        shtrove_ingest_url(),
+        params={
+            'record_identifier': _shtrove_record_identifier(osf_item),
+        },
+        headers=_shtrove_auth_headers(osf_item),
+    )
+
+
+def _shtrove_record_identifier(osf_item):
+    return osf_item.guids.values_list('_id', flat=True).first()
+
+
+def _shtrove_auth_headers(osf_item):
+    _nonfile_item = (
+        osf_item.target
+        if hasattr(osf_item, 'target')
+        else osf_item
+    )
+    _access_token = (
+        _nonfile_item.provider.access_token
+        if getattr(_nonfile_item, 'provider', None) and _nonfile_item.provider.access_token
+        else settings.SHARE_API_TOKEN
+    )
+    return {'Authorization': f'Bearer {_access_token}'}
+
+
 def _should_delete_indexcard(osf_item):
+    if getattr(osf_item, 'is_deleted', False) or getattr(osf_item, 'deleted', None):
+        return True
     # if it quacks like BaseFileNode, look at .target instead
-    _possibly_private_item = getattr(osf_item, 'target', None) or osf_item
+    _containing_item = getattr(osf_item, 'target', None)
+    if _containing_item:
+        return _should_delete_indexcard(_containing_item)
     return (
-        not _is_item_public(_possibly_private_item)
-        or getattr(_possibly_private_item, 'is_spam', False)
-        or is_qa_resource(_possibly_private_item)
+        not _is_item_public(osf_item)
+        or getattr(osf_item, 'is_spam', False)
+        or is_qa_resource(osf_item)
     )
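
With this change, an update to SHARE/trove boils down to a single HTTP request against the new v3 ingest endpoint (the v2 normalizeddata endpoint and its @graph payloads are removed below). A minimal sketch of the equivalent call with plain requests; the helper name and arguments here are illustrative, not part of the patch, and the focus_iri shape is the one the new test helper asserts:

    import requests
    from website import settings

    def send_indexcard_sketch(osfguid, turtle_record, access_token):
        # POST the gathered turtle metadata to SHARE/trove, keyed by the
        # item's osfguid and its DOMAIN-based focus iri
        return requests.post(
            f'{settings.SHARE_URL}api/v3/ingest',
            params={
                'focus_iri': f'{settings.DOMAIN}{osfguid}',
                'record_identifier': osfguid,
            },
            headers={
                'Content-Type': 'text/turtle',  # assumed mediatype of the gathered record
                'Authorization': f'Bearer {access_token}',  # provider token or SHARE_API_TOKEN
            },
            data=turtle_record,
        )
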
diff --git a/api_tests/providers/test_reindex_provider.py b/api_tests/providers/test_reindex_provider.py
index 29ec3b57abf..3d15d44f55f 100644
--- a/api_tests/providers/test_reindex_provider.py
+++ b/api_tests/providers/test_reindex_provider.py
@@ -1,5 +1,6 @@
+from unittest import mock
+
 import pytest
-import json
 
 from django.core.management import call_command
 
@@ -35,16 +36,15 @@ def registration(self, registration_provider):
     def user(self):
         return AuthUserFactory()
 
-    def test_reindex_provider_preprint(self, mock_share, preprint_provider, preprint):
-        call_command('reindex_provider', f'--providers={preprint_provider._id}')
-        data = json.loads(mock_share.calls[-1].request.body)
+    @pytest.fixture()
+    def mock_update_share(self):
+        with mock.patch('osf.management.commands.reindex_provider.update_share') as _mock_update_share:
+            yield _mock_update_share
 
-        assert any(graph for graph in data['data']['attributes']['data']['@graph']
-                   if graph['@type'] == preprint_provider.share_publish_type.lower())
+    def test_reindex_provider_preprint(self, mock_update_share, preprint_provider, preprint):
+        call_command('reindex_provider', f'--providers={preprint_provider._id}')
+        mock_update_share.assert_called_once_with(preprint)
 
-    def test_reindex_provider_registration(self, mock_share, registration_provider, registration):
+    def test_reindex_provider_registration(self, mock_update_share, registration_provider, registration):
         call_command('reindex_provider', f'--providers={registration_provider._id}')
-        data = json.loads(mock_share.calls[-1].request.body)
-
-        assert any(graph for graph in data['data']['attributes']['data']['@graph']
-                   if graph['@type'] == registration_provider.share_publish_type.lower())
+        mock_update_share.assert_called_once_with(registration)
diff --git a/api_tests/share/__init__.py b/api_tests/share/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/api_tests/share/_utils.py b/api_tests/share/_utils.py
new file mode 100644
index 00000000000..caa6bd066b0
--- /dev/null
+++ b/api_tests/share/_utils.py
@@ -0,0 +1,27 @@
+import contextlib
+from urllib.parse import urlsplit
+
+from django.http import QueryDict
+
+from website import settings as website_settings
+
+
+@contextlib.contextmanager
+def expect_ingest_request(mock_share, osfguid, *, token=None, delete=False, count=1):
+    mock_share._calls.reset()
+    yield
+    assert len(mock_share.calls) == count
+    for _call in mock_share.calls:
+        assert_ingest_request(_call.request, osfguid, token=token, delete=delete)
+
+
+def assert_ingest_request(request, expected_osfguid, *, token=None, delete=False):
+    _querydict = QueryDict(urlsplit(request.path_url).query)
+    assert _querydict['record_identifier'] == expected_osfguid
+    if delete:
+        assert request.method == 'DELETE'
+    else:
+        assert request.method == 'POST'
+        assert _querydict['focus_iri'] == f'{website_settings.DOMAIN}{expected_osfguid}'
+    _token = token or website_settings.SHARE_API_TOKEN
+    assert request.headers['Authorization'] == f'Bearer {_token}'
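
The rewritten tests below all funnel through this helper: wrap the action that should (or should not) reach SHARE, and the helper asserts on the resulting ingest traffic. A hypothetical usage sketch, with fixture names borrowed from the tests that follow:

    # expect exactly one POST to the ingest endpoint while the block runs,
    # with record_identifier == node._id and the default mocked api token
    with expect_ingest_request(mock_share, node._id, count=1):
        node.save()

    # delete=True expects a DELETE (index-card removal) instead of a POST;
    # token=... overrides the expected Bearer token for provider-owned items
    with expect_ingest_request(mock_share, node._id, delete=True):
        node.add_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=auth)
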
diff --git a/api_tests/share/test_share_node.py b/api_tests/share/test_share_node.py
index b2f401be354..7671927e733 100644
--- a/api_tests/share/test_share_node.py
+++ b/api_tests/share/test_share_node.py
@@ -1,9 +1,8 @@
-import json
+from unittest.mock import patch
+
 import pytest
 import responses
-from unittest.mock import patch
 
-from api.share.utils import serialize_registration
 from osf.models import CollectionSubmission, SpamStatus, Outcome
 from osf.utils.outcomes import ArtifactTypes
@@ -21,6 +20,8 @@
 from website.project.tasks import on_node_updated
 from framework.auth.core import Auth
 
+from api.share.utils import shtrove_ingest_url
+from ._utils import expect_ingest_request
 
 @pytest.mark.django_db
@@ -28,9 +29,10 @@ class TestNodeShare:
 
     @pytest.fixture(scope='class', autouse=True)
-    def mock_request_identifier_update(self):
+    def _patches(self):
         with patch('osf.models.identifiers.IdentifierMixin.request_identifier_update'):
-            yield
+            with patch.object(settings, 'USE_CELERY', False):
+                yield
 
     @pytest.fixture()
     def user(self):
         return AuthUserFactory()
@@ -97,15 +99,12 @@ def registration_outcome(self, registration):
         return o
 
     def test_update_node_share(self, mock_share, node, user):
-
-        on_node_updated(node._id, user._id, False, {'is_public'})
-
-        assert mock_share.calls[-1].request.headers['Authorization'] == 'Bearer mock-api-token'
+        with expect_ingest_request(mock_share, node._id):
+            on_node_updated(node._id, user._id, False, {'is_public'})
 
     def test_update_registration_share(self, mock_share, registration, user):
-
-        on_node_updated(registration._id, user._id, False, {'is_public'})
-
-        assert mock_share.calls[-1].request.headers['Authorization'] == 'Bearer mock-api-token'
+        with expect_ingest_request(mock_share, registration._id):
+            on_node_updated(registration._id, user._id, False, {'is_public'})
 
     def test_update_share_correctly_for_projects(self, mock_share, node, user):
         cases = [{
@@ -126,12 +125,8 @@ def test_update_share_correctly_for_projects(self, mock_share, node, user):
         for i, case in enumerate(cases):
             for attr, value in case['attrs'].items():
                 setattr(node, attr, value)
-            node.save()
-
-            data = json.loads(mock_share.calls[i].request.body.decode())
-            graph = data['data']['attributes']['data']['@graph']
-            work_node = next(n for n in graph if n['@type'] == 'project')
-            assert work_node['is_deleted'] == case['is_deleted']
+            with expect_ingest_request(mock_share, node._id, delete=case['is_deleted']):
+                node.save()
 
     def test_update_share_correctly_for_registrations(self, mock_share, registration, user):
         cases = [{
@@ -149,136 +144,41 @@ def test_update_share_correctly_for_registrations(self, mock_share, registration
         for i, case in enumerate(cases):
             for attr, value in case['attrs'].items():
                 setattr(registration, attr, value)
-            registration.save()
-
+            with expect_ingest_request(mock_share, registration._id, delete=case['is_deleted']):
+                registration.save()
             assert registration.is_registration
-            data = json.loads(mock_share.calls[i].request.body.decode())
-            graph = data['data']['attributes']['data']['@graph']
-            payload = next((item for item in graph if 'is_deleted' in item.keys()))
-            assert payload['is_deleted'] == case['is_deleted']
-
-    def test_serialize_registration_gets_parent_hierarchy_for_component_registrations(self, project, grandchild_registration):
-        res = serialize_registration(grandchild_registration)
-
-        graph = res['@graph']
-
-        # all three registrations are present...
-        registration_graph_nodes = [n for n in graph if n['@type'] == 'registration']
-        assert len(registration_graph_nodes) == 3
-        root = next(n for n in registration_graph_nodes if n['title'] == 'Root')
-        child = next(n for n in registration_graph_nodes if n['title'] == 'Child')
-        grandchild = next(n for n in registration_graph_nodes if n['title'] == 'Grandchild')
-
-        # ...with the correct 'ispartof' relationships among them (grandchild => child => root)
-        expected_ispartofs = [
-            {
-                '@type': 'ispartof',
-                'subject': {'@id': grandchild['@id'], '@type': 'registration'},
-                'related': {'@id': child['@id'], '@type': 'registration'},
-            }, {
-                '@type': 'ispartof',
-                'subject': {'@id': child['@id'], '@type': 'registration'},
-                'related': {'@id': root['@id'], '@type': 'registration'},
-            },
-        ]
-        actual_ispartofs = [n for n in graph if n['@type'] == 'ispartof']
-        assert len(actual_ispartofs) == 2
-        for expected_ispartof in expected_ispartofs:
-            actual_ispartof = [
-                n for n in actual_ispartofs
-                if expected_ispartof.items() <= n.items()
-            ]
-            assert len(actual_ispartof) == 1
-
-        # ...and each has an identifier
-        for registration_graph_node in registration_graph_nodes:
-            workidentifier_graph_nodes = [
-                n for n in graph
-                if n['@type'] == 'workidentifier'
-                and n['creative_work']['@id'] == registration_graph_node['@id']
-            ]
-            assert len(workidentifier_graph_nodes) == 1
-
-    def test_serialize_registration_sets_osf_related_resource_types(
-        self, mock_share, registration, registration_outcome, user
-    ):
-        graph = serialize_registration(registration)['@graph']
-        registration_graph_node = [n for n in graph if n['@type'] == 'registration'][0]
-
-        expected_resource_types = {
-            'data': True, 'papers': True, 'analytic_code': False, 'materials': False, 'supplements': False
-        }
-
-        assert registration_graph_node['extra']['osf_related_resource_types'] == expected_resource_types
 
     def test_update_share_correctly_for_projects_with_qa_tags(self, mock_share, node, user):
-        node.add_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=Auth(user))
-        on_node_updated(node._id, user._id, False, {'is_public'})
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        payload = next((item for item in graph if 'is_deleted' in item.keys()))
-        assert payload['is_deleted'] is True
-
-        node.remove_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=Auth(user), save=True)
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        payload = next((item for item in graph if 'is_deleted' in item.keys()))
-        assert payload['is_deleted'] is False
+        with expect_ingest_request(mock_share, node._id, delete=True):
+            node.add_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=Auth(user))
+        with expect_ingest_request(mock_share, node._id, delete=False):
+            node.remove_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=Auth(user), save=True)
 
     def test_update_share_correctly_for_registrations_with_qa_tags(self, mock_share, registration, user):
-        registration.add_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=Auth(user))
-        on_node_updated(registration._id, user._id, False, {'is_public'})
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-
-        graph = data['data']['attributes']['data']['@graph']
-        payload = next((item for item in graph if 'is_deleted' in item.keys()))
-        assert payload['is_deleted'] is True
-
-        registration.remove_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=Auth(user), save=True)
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        payload = next((item for item in graph if 'is_deleted' in item.keys()))
-        assert payload['is_deleted'] is False
+        with expect_ingest_request(mock_share, registration._id, delete=True):
+            registration.add_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=Auth(user))
+        with expect_ingest_request(mock_share, registration._id):
+            registration.remove_tag(settings.DO_NOT_INDEX_LIST['tags'][0], auth=Auth(user), save=True)
 
     def test_update_share_correctly_for_projects_with_qa_titles(self, mock_share, node, user):
         node.title = settings.DO_NOT_INDEX_LIST['titles'][0] + ' arbitrary text for test title.'
         node.save()
-        on_node_updated(node._id, user._id, False, {'is_public'})
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-
-        graph = data['data']['attributes']['data']['@graph']
-        payload = next((item for item in graph if 'is_deleted' in item.keys()))
-        assert payload['is_deleted'] is True
-
+        with expect_ingest_request(mock_share, node._id, delete=True):
+            on_node_updated(node._id, user._id, False, {'is_public'})
         node.title = 'Not a qa title'
-        node.save()
+        with expect_ingest_request(mock_share, node._id):
+            node.save()
         assert node.title not in settings.DO_NOT_INDEX_LIST['titles']
 
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        payload = next((item for item in graph if 'is_deleted' in item.keys()))
-        assert payload['is_deleted'] is False
-
     def test_update_share_correctly_for_registrations_with_qa_titles(self, mock_share, registration, user):
         registration.title = settings.DO_NOT_INDEX_LIST['titles'][0] + ' arbitrary text for test title.'
-        registration.save()
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        payload = next((item for item in graph if 'is_deleted' in item.keys()))
-        assert payload['is_deleted'] is True
-
+        with expect_ingest_request(mock_share, registration._id, delete=True):
+            registration.save()
         registration.title = 'Not a qa title'
-        registration.save()
+        with expect_ingest_request(mock_share, registration._id):
+            registration.save()
         assert registration.title not in settings.DO_NOT_INDEX_LIST['titles']
 
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        payload = next((item for item in graph if 'is_deleted' in item.keys()))
-        assert payload['is_deleted'] is False
-
     @responses.activate
     def test_skips_no_settings(self, node, user):
         on_node_updated(node._id, user._id, False, {'is_public'})
@@ -286,48 +186,19 @@ def test_call_async_update_on_500_retry(self, mock_share, node, user):
         """This is meant to simulate a temporary outage, so the retry mechanism should kick in and complete it."""
-        mock_share.replace(responses.POST, f'{settings.SHARE_URL}api/v2/normalizeddata/', status=500)
-        mock_share.add(responses.POST, f'{settings.SHARE_URL}api/v2/normalizeddata/', status=200)
-
-        mock_share._calls.reset()  # reset after factory calls
-        on_node_updated(node._id, user._id, False, {'is_public'})
-        assert len(mock_share.calls) == 2
-
-        data = json.loads(mock_share.calls[0].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        identifier_node = next(n for n in graph if n['@type'] == 'workidentifier')
-        assert identifier_node['uri'] == f'{settings.DOMAIN}{node._id}/'
-
-        data = json.loads(mock_share.calls[1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        identifier_node = next(n for n in graph if n['@type'] == 'workidentifier')
-        assert identifier_node['uri'] == f'{settings.DOMAIN}{node._id}/'
+        mock_share.replace(responses.POST, shtrove_ingest_url(), status=500)
+        mock_share.add(responses.POST, shtrove_ingest_url(), status=200)
+        with expect_ingest_request(mock_share, node._id, count=2):
+            on_node_updated(node._id, user._id, False, {'is_public'})
 
     def test_call_async_update_on_500_failure(self, mock_share, node, user):
         """This is meant to simulate a total outage, so the retry mechanism should try X number of times and quit."""
         mock_share.assert_all_requests_are_fired = False  # allows it to retry indefinitely
-        mock_share.replace(responses.POST, f'{settings.SHARE_URL}api/v2/normalizeddata/', status=500)
-
-        mock_share._calls.reset()  # reset after factory calls
-        on_node_updated(node._id, user._id, False, {'is_public'})
-
-        assert len(mock_share.calls) == 6  # first request and five retries
-        data = json.loads(mock_share.calls[0].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        identifier_node = next(n for n in graph if n['@type'] == 'workidentifier')
-        assert identifier_node['uri'] == f'{settings.DOMAIN}{node._id}/'
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        identifier_node = next(n for n in graph if n['@type'] == 'workidentifier')
-        assert identifier_node['uri'] == f'{settings.DOMAIN}{node._id}/'
+        mock_share.replace(responses.POST, shtrove_ingest_url(), status=500)
+        with expect_ingest_request(mock_share, node._id, count=5):  # tries five times
+            on_node_updated(node._id, user._id, False, {'is_public'})
 
     def test_no_call_async_update_on_400_failure(self, mock_share, node, user):
-        mock_share.replace(responses.POST, f'{settings.SHARE_URL}api/v2/normalizeddata/', status=400)
-
-        on_node_updated(node._id, user._id, False, {'is_public'})
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        identifier_node = next(n for n in graph if n['@type'] == 'workidentifier')
-        assert identifier_node['uri'] == f'{settings.DOMAIN}{node._id}/'
+        mock_share.replace(responses.POST, shtrove_ingest_url(), status=400)
+        with expect_ingest_request(mock_share, node._id):
+            on_node_updated(node._id, user._id, False, {'is_public'})
diff --git a/api_tests/share/test_share_preprint.py b/api_tests/share/test_share_preprint.py
index fc65b9cd88a..fcb9e770835 100644
--- a/api_tests/share/test_share_preprint.py
+++ b/api_tests/share/test_share_preprint.py
@@ -1,18 +1,15 @@
+import contextlib
 from datetime import datetime
-import json
 
-import mock
+from unittest import mock
+
 import pytest
 import responses
 
-from api.share.utils import update_share
-
-from api_tests.utils import create_test_file
-
+from api.share.utils import shtrove_ingest_url
 from framework.auth.core import Auth
-
+from framework.postcommit_tasks.handlers import postcommit_after_request, postcommit_celery_queue, postcommit_queue
 from osf.models.spam import SpamStatus
 from osf.utils.permissions import READ, WRITE, ADMIN
-
 from osf_tests.factories import (
     AuthUserFactory,
     ProjectFactory,
@@ -20,14 +17,18 @@
     PreprintFactory,
     PreprintProviderFactory,
 )
-
 from website import settings
 from website.preprints.tasks import on_preprint_updated
+from ._utils import expect_ingest_request
 
 
 @pytest.mark.django_db
 @pytest.mark.enable_enqueue_task
 class TestPreprintShare:
+    @pytest.fixture(scope='class', autouse=True)
+    def _patches(self):
+        with mock.patch.object(settings, 'USE_CELERY', False):
+            yield
 
     @pytest.fixture
     def user(self):
         return AuthUserFactory()
@@ -56,10 +57,6 @@ def subject(self):
     def subject_two(self):
         return SubjectFactory(text='Subject #2')
 
-    @pytest.fixture
-    def file(self, project, user):
-        return create_test_file(project, user, 'second_place.pdf')
-
     @pytest.fixture
     def preprint(self, project, user, provider, subject):
         return PreprintFactory(
@@ -72,120 +69,102 @@ def preprint(self, project, user, provider, subject):
         )
 
     def test_save_unpublished_not_called(self, mock_share, preprint):
-        mock_share.reset()  # if the call is not made responses would raise an assertion error, if not reset.
-        preprint.save()
-        assert not len(mock_share.calls)
+        # expecting no ingest requests (delete or otherwise)
+        with _expect_preprint_ingest_request(mock_share, preprint, count=0):
+            preprint.save()
 
-    @mock.patch('osf.models.preprint.update_or_enqueue_on_preprint_updated')
-    def test_save_published_called(self, mock_on_preprint_updated, preprint, user, auth):
-        preprint.set_published(True, auth=auth, save=True)
-        assert mock_on_preprint_updated.called
+    def test_save_published_called(self, mock_share, preprint, user, auth):
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.set_published(True, auth=auth, save=True)
 
     # This covers an edge case where a preprint is forced back to unpublished,
    # ensuring the information is sent back to share
-    @mock.patch('osf.models.preprint.update_or_enqueue_on_preprint_updated')
-    def test_save_unpublished_called_forced(self, mock_on_preprint_updated, auth, preprint):
+    def test_save_unpublished_called_forced(self, mock_share, auth, preprint):
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.set_published(True, auth=auth, save=True)
+        with _expect_preprint_ingest_request(mock_share, preprint, delete=True):
+            preprint.is_published = False
+            preprint.save(**{'force_update': True})
+
+    def test_save_published_subject_change_called(self, mock_share, auth, preprint, subject, subject_two):
         preprint.set_published(True, auth=auth, save=True)
-        preprint.is_published = False
-        preprint.save(**{'force_update': True})
-        assert mock_on_preprint_updated.call_count == 2
-
-    @mock.patch('osf.models.preprint.update_or_enqueue_on_preprint_updated')
-    def test_save_published_subject_change_called(self, mock_on_preprint_updated, auth, preprint, subject, subject_two):
-        preprint.is_published = True
-        preprint.set_subjects([[subject_two._id]], auth=auth)
-        assert mock_on_preprint_updated.called
-        call_args, call_kwargs = mock_on_preprint_updated.call_args
-        assert [subject.id] in mock_on_preprint_updated.call_args[1].values()
-
-    @mock.patch('osf.models.preprint.update_or_enqueue_on_preprint_updated')
-    def test_save_unpublished_subject_change_not_called(self, mock_on_preprint_updated, auth, preprint, subject_two):
-        preprint.set_subjects([[subject_two._id]], auth=auth)
-        assert not mock_on_preprint_updated.called
-
-    def test_send_to_share_is_true(self, mock_share, preprint):
-        on_preprint_updated(preprint._id)
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        assert data['data']['attributes']['data']['@graph']
-        assert mock_share.calls[-1].request.headers['Authorization'] == 'Bearer Snowmobiling'
-
-    @mock.patch('osf.models.preprint.update_or_enqueue_on_preprint_updated')
-    def test_preprint_contributor_changes_updates_preprints_share(self, mock_on_preprint_updated, user, file, auth):
-        preprint = PreprintFactory(is_published=True, creator=user)
-        assert mock_on_preprint_updated.call_count == 2
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.set_subjects([[subject_two._id]], auth=auth)
+
+    def test_save_unpublished_subject_change_not_called(self, mock_share, auth, preprint, subject_two):
+        with _expect_preprint_ingest_request(mock_share, preprint, delete=True):
+            preprint.set_subjects([[subject_two._id]], auth=auth)
 
+    def test_send_to_share_is_true(self, mock_share, auth, preprint):
+        preprint.set_published(True, auth=auth, save=True)
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            on_preprint_updated(preprint._id, saved_fields=['title'])
+
+    def test_preprint_contributor_changes_updates_preprints_share(self, mock_share, user, auth):
+        preprint = PreprintFactory(is_published=True, creator=user)
+        preprint.set_published(True, auth=auth, save=True)
         user2 = AuthUserFactory()
-        preprint.primary_file = file
 
-        preprint.add_contributor(contributor=user2, auth=auth, save=True)
-        assert mock_on_preprint_updated.call_count == 5
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.add_contributor(contributor=user2, auth=auth, save=True)
 
-        preprint.move_contributor(contributor=user, index=0, auth=auth, save=True)
-        assert mock_on_preprint_updated.call_count == 7
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.move_contributor(contributor=user, index=0, auth=auth, save=True)
 
         data = [{'id': user._id, 'permissions': ADMIN, 'visible': True},
                 {'id': user2._id, 'permissions': WRITE, 'visible': False}]
-        preprint.manage_contributors(data, auth=auth, save=True)
-        assert mock_on_preprint_updated.call_count == 9
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.manage_contributors(data, auth=auth, save=True)
 
-        preprint.update_contributor(user2, READ, True, auth=auth, save=True)
-        assert mock_on_preprint_updated.call_count == 11
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.update_contributor(user2, READ, True, auth=auth, save=True)
 
-        preprint.remove_contributor(contributor=user2, auth=auth)
-        assert mock_on_preprint_updated.call_count == 13
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.remove_contributor(contributor=user2, auth=auth)
 
-    def test_call_async_update_on_500_failure(self, mock_share, preprint):
-        mock_share.replace(responses.POST, f'{settings.SHARE_URL}api/v2/normalizeddata/', status=500)
-
-        mock_share._calls.reset()  # reset after factory calls
-        update_share(preprint)
-
-        assert len(mock_share.calls) == 6  # first request and five retries
-        data = json.loads(mock_share.calls[0].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        data = next(data for data in graph if data['@type'] == 'preprint')
-        assert data['title'] == preprint.title
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        data = next(data for data in graph if data['@type'] == 'preprint')
-        assert data['title'] == preprint.title
+    def test_call_async_update_on_500_failure(self, mock_share, preprint, auth):
+        mock_share.replace(responses.POST, shtrove_ingest_url(), status=500)
+        preprint.set_published(True, auth=auth, save=True)
+        with _expect_preprint_ingest_request(mock_share, preprint, count=5):
+            preprint.update_search()
 
-    def test_no_call_async_update_on_400_failure(self, mock_share, preprint):
-        mock_share.replace(responses.POST, f'{settings.SHARE_URL}api/v2/normalizeddata/', status=400)
-
-        mock_share._calls.reset()  # reset after factory calls
-        update_share(preprint)
-
-        assert len(mock_share.calls) == 1
-        data = json.loads(mock_share.calls[0].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        data = next(data for data in graph if data['@type'] == 'preprint')
-        assert data['title'] == preprint.title
+    def test_no_call_async_update_on_400_failure(self, mock_share, preprint, auth):
+        mock_share.replace(responses.POST, shtrove_ingest_url(), status=400)
+        preprint.set_published(True, auth=auth, save=True)
+        with _expect_preprint_ingest_request(mock_share, preprint, count=1):
+            preprint.update_search()
 
     def test_delete_from_share(self, mock_share):
         preprint = PreprintFactory()
-        update_share(preprint)
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        share_preprint = next(n for n in graph if n['@type'] == 'preprint')
-        assert not share_preprint['is_deleted']
-
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.update_search()
         preprint.date_withdrawn = datetime.now()
-        update_share(preprint)
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        share_preprint = next(n for n in graph if n['@type'] == 'preprint')
-        assert not share_preprint['is_deleted']
-
+        preprint.save()
+        with _expect_preprint_ingest_request(mock_share, preprint):
+            preprint.update_search()
         preprint.spam_status = SpamStatus.SPAM
-        update_share(preprint)
-
-        data = json.loads(mock_share.calls[-1].request.body.decode())
-        graph = data['data']['attributes']['data']['@graph']
-        share_preprint = next(n for n in graph if n['@type'] == 'preprint')
-        assert share_preprint['is_deleted']
+        preprint.save()
+        with _expect_preprint_ingest_request(mock_share, preprint, delete=True):
+            preprint.update_search()
+
+
+@contextlib.contextmanager
+def _expect_preprint_ingest_request(mock_share, preprint, *, delete=False, count=1):
+    # same as expect_ingest_request, but with convenience for preprint specifics
+    # and postcommit-task handling (so on_preprint_updated actually runs)
+    with expect_ingest_request(
+        mock_share,
+        preprint._id,
+        token=preprint.provider.access_token,
+        delete=delete,
+        count=count,
+    ):
+        # clear out postcommit tasks from factories
+        postcommit_queue().clear()
+        postcommit_celery_queue().clear()
+        yield
+        _mock_request = mock.Mock()
+        _mock_request.status_code = 200
+        # run postcommit tasks (specifically care about on_preprint_updated)
+        postcommit_after_request(_mock_request)
diff --git a/conftest.py b/conftest.py
index 24afe25f0b4..54f195fe3c7 100644
--- a/conftest.py
+++ b/conftest.py
@@ -12,6 +12,7 @@
 import responses
 import xml.etree.ElementTree as ET
 
+from api.share.utils import shtrove_ingest_url
 from framework.celery_tasks import app as celery_app
 from website import settings as website_settings
 
@@ -153,8 +154,10 @@ def teardown_es():
 def mock_share():
     with mock.patch('api.share.utils.settings.SHARE_ENABLED', True):
         with mock.patch('api.share.utils.settings.SHARE_API_TOKEN', 'mock-api-token'):
-            with responses.RequestsMock(assert_all_requests_are_fired=True) as rsps:
-                rsps.add(responses.POST, f'{website_settings.SHARE_URL}api/v2/normalizeddata/', status=200)
+            with responses.RequestsMock(assert_all_requests_are_fired=False) as rsps:
+                _ingest_url = shtrove_ingest_url()
+                rsps.add(responses.POST, _ingest_url, status=200)
+                rsps.add(responses.DELETE, _ingest_url, status=200)
                 yield rsps
diff --git a/framework/celery_tasks/handlers.py b/framework/celery_tasks/handlers.py
index d9f616d3813..9cf85045d04 100644
--- a/framework/celery_tasks/handlers.py
+++ b/framework/celery_tasks/handlers.py
@@ -39,7 +39,7 @@ def celery_teardown_request(error=None):
             group(queue()).apply_async()
         else:
             for task in queue():
-                task()
+                task.apply()
 
 
 def get_task_from_queue(name, predicate):
@@ -66,7 +66,7 @@ def _enqueue_task(signature):
         context_stack.top is None and
         getattr(api_globals, 'request', None) is None
     ):  # Not in a request context
-        signature()
+        signature.apply()
     else:
         if signature not in queue():
             queue().append(signature)
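
A note on the two-character change above: calling a Celery signature directly just runs the task body, so for a bind=True task like task__update_share, self.retry() raises Retry to the caller instead of being handled; Signature.apply() runs the task eagerly but through Celery's machinery, which catches Retry and re-runs the task up to max_retries. That is what lets the retry-on-500 tests above observe multiple ingest attempts in-process. A minimal sketch (names illustrative, not from this patch):

    from celery import Celery

    app = Celery()

    @app.task(bind=True, max_retries=4)
    def task__sketch(self):
        try:
            raise RuntimeError('simulated 500')
        except RuntimeError as exc:
            self.retry(exc=exc)  # honored under .apply(); raised to the caller when called bare

    # task__sketch.s()() would just invoke the function; .apply() runs it
    # eagerly through Celery, retrying up to max_retries before failing
    task__sketch.s().apply()
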
diff --git a/osf/metadata/tools.py b/osf/metadata/tools.py
index 0f787f61fa7..2cd4699413f 100644
--- a/osf/metadata/tools.py
+++ b/osf/metadata/tools.py
@@ -2,11 +2,8 @@
 '''
 import typing
 
-import requests
-
-from website import settings as website_settings
-
 from osf.models.base import coerce_guid
-from osf.metadata.osf_gathering import pls_get_magic_metadata_basket, osf_iri
+from osf.metadata.osf_gathering import pls_get_magic_metadata_basket
 from osf.metadata.serializers import get_metadata_serializer
 
 
@@ -44,54 +41,3 @@ def pls_gather_metadata_file(osf_item, format_key, serializer_config=None) -> Se
         filename=serializer.filename_for_itemid(osfguid._id),
         serialized_metadata=serializer.serialize(),
     )
-
-
-def pls_send_trove_indexcard(osf_item):
-    _iri = osf_iri(osf_item)
-    if not _iri:
-        raise ValueError(f'could not get iri for {osf_item}')
-    _metadata_record = pls_gather_metadata_file(osf_item, 'turtle')
-    return requests.post(
-        _shtrove_ingest_url(),
-        params={
-            'focus_iri': _iri,
-            'record_identifier': _shtrove_record_identifier(osf_item),
-        },
-        headers={
-            'Content-Type': _metadata_record.mediatype,
-            **_shtrove_auth_headers(osf_item),
-        },
-        data=_metadata_record.serialized_metadata,
-    )
-
-
-def pls_delete_trove_indexcard(osf_item):
-    return requests.delete(
-        _shtrove_ingest_url(),
-        params={
-            'record_identifier': _shtrove_record_identifier(osf_item),
-        },
-        headers=_shtrove_auth_headers(osf_item),
-    )
-
-
-def _shtrove_record_identifier(osf_item):
-    return osf_item.guids.values_list('_id', flat=True).first()
-
-
-def _shtrove_ingest_url():
-    return f'{website_settings.SHARE_URL}api/v3/ingest'
-
-
-def _shtrove_auth_headers(osf_item):
-    _nonfile_item = (
-        osf_item.target
-        if hasattr(osf_item, 'target')
-        else osf_item
-    )
-    _access_token = (
-        _nonfile_item.provider.access_token
-        if _nonfile_item.provider and _nonfile_item.provider.access_token
-        else website_settings.SHARE_API_TOKEN
-    )
-    return {'Authorization': f'Bearer {_access_token}'}
diff --git a/osf/models/mixins.py b/osf/models/mixins.py
index 7f947130d2f..e2201777232 100644
--- a/osf/models/mixins.py
+++ b/osf/models/mixins.py
@@ -1136,6 +1136,7 @@ def set_subjects(self, new_subjects, auth, add_log=True):
             self.add_subjects_log(old_subjects, auth)
 
         self.save()
+        self.update_or_enqueue_on_resource_updated(auth.user._id, first_save=False, saved_fields=['subjects'])
 
     def set_subjects_from_relationships(self, subjects_list, auth, add_log=True):
         """ Helper for setting M2M subjects field from list of flattened subjects received from UI.
@@ -1161,6 +1162,7 @@ def set_subjects_from_relationships(self, subjects_list, auth, add_log=True):
             self.add_subjects_log(old_subjects, auth)
 
         self.save()
+        self.update_or_enqueue_on_resource_updated(auth.user._id, saved_fields=['subjects'])
 
     def map_subjects_between_providers(self, old_provider, new_provider, auth=None):
         """
diff --git a/website/preprints/tasks.py b/website/preprints/tasks.py
index 87f62c569a1..0be65e3ad05 100644
--- a/website/preprints/tasks.py
+++ b/website/preprints/tasks.py
@@ -15,6 +15,8 @@ def on_preprint_updated(preprint_id, saved_fields=None, **kwargs):
     # transactions are implemented in View and Task application layers.
     from osf.models import Preprint
     preprint = Preprint.load(preprint_id)
+    if not preprint:
+        return
 
     need_update = bool(preprint.SEARCH_UPDATE_FIELDS.intersection(saved_fields or {}))
     if need_update: