diff --git a/ecommerce/extensions/payment/tests/views/test_sdn.py b/ecommerce/extensions/payment/tests/views/test_sdn.py index c89254bdbd3..a1fca2ad4ed 100644 --- a/ecommerce/extensions/payment/tests/views/test_sdn.py +++ b/ecommerce/extensions/payment/tests/views/test_sdn.py @@ -5,6 +5,7 @@ from django.urls import reverse from requests.exceptions import HTTPError +from ecommerce.extensions.api.tests.test_authentication import AccessTokenMixin from ecommerce.extensions.payment.models import SDNCheckFailure from ecommerce.tests.testcases import TestCase @@ -19,13 +20,14 @@ def test_sdn_logout_context(self): self.assertEqual(response.context['logout_url'], logout_url) -class SDNCheckViewTests(TestCase): +class SDNCheckViewTests(AccessTokenMixin, TestCase): sdn_check_path = reverse('sdn:check') def setUp(self): super().setUp() - self.user = self.create_user() + self.user = self.create_user(is_staff=True) self.client.login(username=self.user.username, password=self.password) + self.token = self.generate_jwt_token_header(self.user) self.post_params = { 'lms_user_id': 1337, 'name': 'Bowser, King of the Koopas', @@ -34,7 +36,7 @@ def setUp(self): } def test_sdn_check_missing_args(self): - response = self.client.post(self.sdn_check_path) + response = self.client.post(self.sdn_check_path, HTTP_AUTHORIZATION=self.token) assert response.status_code == 400 @mock.patch('ecommerce.extensions.payment.views.sdn.checkSDNFallback') @@ -42,7 +44,7 @@ def test_sdn_check_missing_args(self): def test_sdn_check_search_fails_uses_fallback(self, mock_search, mock_fallback): mock_search.side_effect = [HTTPError] mock_fallback.return_value = 0 - response = self.client.post(self.sdn_check_path, data=self.post_params) + response = self.client.post(self.sdn_check_path, data=self.post_params, HTTP_AUTHORIZATION=self.token) assert response.status_code == 200 assert response.json()['hit_count'] == 0 @@ -50,7 +52,7 @@ def test_sdn_check_search_fails_uses_fallback(self, mock_search, mock_fallback): @mock.patch('ecommerce.extensions.payment.views.sdn.SDNClient.search') def test_sdn_check_search_succeeds(self, mock_search, mock_fallback): mock_search.return_value = {'total': 4} - response = self.client.post(self.sdn_check_path, data=self.post_params) + response = self.client.post(self.sdn_check_path, data=self.post_params, HTTP_AUTHORIZATION=self.token) assert response.status_code == 200 assert response.json()['hit_count'] == 4 assert response.json()['sdn_response'] == {'total': 4} @@ -64,6 +66,7 @@ def setUp(self): super().setUp() self.user = self.create_user(is_staff=True) self.client.login(username=self.user.username, password=self.password) + self.token = self.generate_jwt_token_header(self.user) self.post_params = { 'full_name': 'Princess Peach', 'username': 'toadstool_is_cool', @@ -77,25 +80,29 @@ def setUp(self): def test_non_staff_cannot_access_endpoint(self): self.user.is_staff = False self.user.save() - response = self.client.post(self.sdn_check_path, data=self.post_params, content_type='application/json') + response = self.client.post(self.sdn_check_path, data=self.post_params, content_type='application/json', + HTTP_AUTHORIZATION=self.token) assert response.status_code == 403 def test_missing_payload_arg_400(self): del self.post_params['full_name'] - response = self.client.post(self.sdn_check_path, data=self.post_params, content_type='application/json') + response = self.client.post(self.sdn_check_path, data=self.post_params, content_type='application/json', + HTTP_AUTHORIZATION=self.token) assert response.status_code == 400 def test_sdn_response_response_missing_required_field_400(self): del self.post_params['sdn_check_response']['total'] assert 'sdn_check_response' in self.post_params # so it's clear we deleted the sub dict's key - response = self.client.post(self.sdn_check_path, data=self.post_params, content_type='application/json') + response = self.client.post(self.sdn_check_path, data=self.post_params, content_type='application/json', + HTTP_AUTHORIZATION=self.token) assert response.status_code == 400 def test_happy_path_create(self): assert SDNCheckFailure.objects.count() == 0 json_payload = json.dumps(self.post_params) - response = self.client.post(self.sdn_check_path, data=json_payload, content_type='application/json') + response = self.client.post(self.sdn_check_path, data=json_payload, content_type='application/json', + HTTP_AUTHORIZATION=self.token) assert response.status_code == 201 assert SDNCheckFailure.objects.count() == 1 diff --git a/ecommerce/extensions/payment/views/sdn.py b/ecommerce/extensions/payment/views/sdn.py index 1b450bc9c4d..117a534580f 100644 --- a/ecommerce/extensions/payment/views/sdn.py +++ b/ecommerce/extensions/payment/views/sdn.py @@ -1,10 +1,9 @@ import logging from django.conf import settings -from django.contrib.auth.decorators import login_required from django.http import JsonResponse -from django.utils.decorators import method_decorator -from django.views.generic import TemplateView, View +from django.views.generic import TemplateView +from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from requests.exceptions import HTTPError, Timeout from rest_framework import status, views from rest_framework.permissions import IsAdminUser, IsAuthenticated @@ -21,7 +20,8 @@ class SDNCheckFailureView(views.APIView): REST API for SDNCheckFailure class. """ http_method_names = ['post', 'options'] - permission_classes = [IsAuthenticated, IsAdminUser] + authentication_classes = (JwtAuthentication,) + permission_classes = (IsAuthenticated, IsAdminUser) serializer_class = SDNCheckFailureSerializer def _validate_arguments(self, payload): @@ -83,7 +83,7 @@ def get_context_data(self, **kwargs): return context -class SDNCheckView(View): +class SDNCheckView(views.APIView): """ View for external services to use to run SDN checks against. @@ -91,8 +91,9 @@ class SDNCheckView(View): not called during a normal checkout flow (as of 6/8/2023). """ http_method_names = ['post', 'options'] + authentication_classes = (JwtAuthentication,) + permission_classes = (IsAuthenticated, IsAdminUser) - @method_decorator(login_required) def post(self, request): """ Use data provided to check against SDN list.