Skip to content

Commit 0e892ed

Browse files
author
gitstart_bot
committed
chore: implement ip restriction on api endpoints
1 parent 0c74869 commit 0e892ed

File tree

6 files changed

+90
-23
lines changed

6 files changed

+90
-23
lines changed

server/covmanager/views.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
from django.shortcuts import get_object_or_404, redirect, render
1111
from django.views.decorators.csrf import csrf_exempt
1212
from rest_framework import filters, mixins, viewsets
13-
from rest_framework.authentication import SessionAuthentication, TokenAuthentication
13+
from rest_framework.authentication import SessionAuthentication
1414

1515
from crashmanager.models import Tool
1616
from server.views import JsonQueryFilterBackend, SimpleQueryFilterBackend
17+
from server.utils import IPRestrictedTokenAuthentication
1718

1819
from .models import Collection, Report, ReportConfiguration, ReportSummary, Repository
1920
from .serializers import (
@@ -705,7 +706,7 @@ class CollectionViewSet(
705706
API endpoint that allows adding/viewing Collections
706707
"""
707708

708-
authentication_classes = (TokenAuthentication, SessionAuthentication)
709+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
709710
queryset = Collection.objects.all()
710711
serializer_class = CollectionSerializer
711712
paginate_by_param = "limit"
@@ -754,7 +755,7 @@ class ReportViewSet(
754755
API endpoint that allows viewing Reports
755756
"""
756757

757-
authentication_classes = (TokenAuthentication, SessionAuthentication)
758+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
758759
queryset = Report.objects.all()
759760
serializer_class = ReportSerializer
760761
paginate_by_param = "limit"
@@ -778,7 +779,7 @@ class RepositoryViewSet(
778779
API endpoint that allows viewing Repositories
779780
"""
780781

781-
authentication_classes = (TokenAuthentication,)
782+
authentication_classes = (IPRestrictedTokenAuthentication,)
782783
queryset = Repository.objects.all()
783784
serializer_class = RepositorySerializer
784785
filter_backends = [JsonQueryFilterBackend]
@@ -819,7 +820,7 @@ class ReportConfigurationViewSet(
819820
API endpoint that allows adding/updating/viewing Report Configurations
820821
"""
821822

822-
authentication_classes = (TokenAuthentication, SessionAuthentication)
823+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
823824
queryset = ReportConfiguration.objects.all()
824825
serializer_class = ReportConfigurationSerializer
825826
filter_backends = [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
from django.conf import settings
3+
from django.contrib.auth.models import User
4+
from rest_framework.authtoken.models import Token
5+
import requests
6+
7+
@pytest.fixture
8+
def allowed_ip():
9+
return "127.0.0.1"
10+
11+
@pytest.fixture
12+
def blocked_ip():
13+
return "203.0.113.10"
14+
15+
@pytest.fixture(autouse=True)
16+
def override_settings(allowed_ip):
17+
settings.ALLOWED_IPS = [allowed_ip]
18+
19+
@pytest.mark.django_db
20+
def test_allowed_ip_can_authenticate(api_client, user_normal, allowed_ip):
21+
"""Ensure authentication works for an allowed IP"""
22+
response = api_client.get("/crashmanager/rest/crashes/", REMOTE_ADDR=allowed_ip)
23+
assert response.status_code == 200
24+
25+
@pytest.mark.django_db
26+
def test_blocked_ip_is_rejected(api_client, allowed_ip, blocked_ip):
27+
"""Ensure authentication fails for a blocked IP"""
28+
response = api_client.get("/crashmanager/rest/crashes/", REMOTE_ADDR=blocked_ip)
29+
assert response.status_code == 403
30+
assert response.json()['detail'] == 'IP address restricted. Access denied.'
31+

server/crashmanager/views.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from django.views.generic.list import ListView
1919
from notifications.models import Notification
2020
from rest_framework import mixins, status, viewsets
21-
from rest_framework.authentication import SessionAuthentication, TokenAuthentication
21+
from rest_framework.authentication import SessionAuthentication
2222
from rest_framework.decorators import action
2323
from rest_framework.exceptions import MethodNotAllowed, ValidationError
2424
from rest_framework.filters import BaseFilterBackend, OrderingFilter
@@ -28,6 +28,7 @@
2828
from FTB.ProgramConfiguration import ProgramConfiguration
2929
from FTB.Signatures.CrashInfo import CrashInfo
3030
from server.auth import CheckAppPermission
31+
from server.utils import IPRestrictedTokenAuthentication
3132

3233
from .forms import (
3334
BugzillaTemplateBugForm,
@@ -1017,7 +1018,7 @@ class CrashEntryViewSet(
10171018
):
10181019
"""API endpoint that allows adding/viewing CrashEntries"""
10191020

1020-
authentication_classes = (TokenAuthentication, SessionAuthentication)
1021+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
10211022
queryset = CrashEntry.objects.all().select_related(
10221023
"product", "platform", "os", "client", "tool", "testcase"
10231024
)
@@ -1143,7 +1144,7 @@ class BucketViewSet(
11431144
):
11441145
"""API endpoint that allows viewing Buckets"""
11451146

1146-
authentication_classes = (TokenAuthentication, SessionAuthentication)
1147+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
11471148
queryset = Bucket.objects.all().select_related("bug", "bug__externalType")
11481149
serializer_class = BucketSerializer
11491150
filter_backends = [
@@ -1430,7 +1431,7 @@ class BugProviderViewSet(mixins.ListModelMixin, viewsets.GenericViewSet):
14301431
API endpoint that allows listing BugProviders
14311432
"""
14321433

1433-
authentication_classes = (TokenAuthentication, SessionAuthentication)
1434+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
14341435
queryset = BugProvider.objects.all()
14351436
serializer_class = BugProviderSerializer
14361437

@@ -1442,7 +1443,7 @@ class BugzillaTemplateViewSet(
14421443
API endpoint that allows viewing BugzillaTemplates
14431444
"""
14441445

1445-
authentication_classes = (TokenAuthentication, SessionAuthentication)
1446+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
14461447
queryset = BugzillaTemplate.objects.all()
14471448
serializer_class = BugzillaTemplateSerializer
14481449

@@ -1452,7 +1453,7 @@ class NotificationViewSet(mixins.ListModelMixin, viewsets.GenericViewSet):
14521453
API endpoint that allows listing unread Notifications
14531454
"""
14541455

1455-
authentication_classes = (TokenAuthentication, SessionAuthentication)
1456+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
14561457
serializer_class = NotificationSerializer
14571458
filter_backends = [
14581459
JsonQueryFilterBackend,
@@ -1545,7 +1546,7 @@ def get_query_obj(obj, key=None):
15451546

15461547

15471548
class AbstractDownloadView(APIView):
1548-
authentication_classes = (TokenAuthentication, SessionAuthentication)
1549+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
15491550
permission_classes = (CheckAppPermission,)
15501551

15511552
def response(self, file_path, filename, content_type="application/octet-stream"):
@@ -1765,7 +1766,7 @@ class CrashStatsViewSet(viewsets.GenericViewSet):
17651766
API endpoint that allows retrieving CrashManager statistics
17661767
"""
17671768

1768-
authentication_classes = (TokenAuthentication, SessionAuthentication)
1769+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
17691770
queryset = CrashEntry.objects.all()
17701771
filter_backends = [
17711772
ToolFilterCrashesBackend,

server/ec2spotmanager/views.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
from django.shortcuts import get_object_or_404, redirect, render
1515
from django.utils.timezone import now, timedelta
1616
from rest_framework import mixins, serializers, status, viewsets
17-
from rest_framework.authentication import SessionAuthentication, TokenAuthentication
17+
from rest_framework.authentication import SessionAuthentication
1818
from rest_framework.response import Response
1919
from rest_framework.views import APIView
2020

2121
from server.auth import CheckAppPermission
22+
from server.utils import IPRestrictedTokenAuthentication
2223

2324
from .CloudProvider.CloudProvider import (
2425
INSTANCE_STATE,
@@ -917,7 +918,7 @@ def get_labels(self, pool, entries):
917918

918919

919920
class MachineStatusViewSet(APIView):
920-
authentication_classes = (TokenAuthentication,)
921+
authentication_classes = (IPRestrictedTokenAuthentication,)
921922

922923
def get(self, request, *args, **kwargs):
923924
result = {}
@@ -947,7 +948,7 @@ class PoolConfigurationViewSet(
947948
API endpoint that allows viewing PoolConfigurations
948949
"""
949950

950-
authentication_classes = (TokenAuthentication,)
951+
authentication_classes = (IPRestrictedTokenAuthentication,)
951952
permission_classes = (CheckAppPermission,)
952953
queryset = PoolConfiguration.objects.all()
953954
serializer_class = PoolConfigurationSerializer
@@ -964,7 +965,7 @@ def retrieve(self, request, *args, **kwds):
964965

965966

966967
class PoolCycleView(APIView):
967-
authentication_classes = (TokenAuthentication,)
968+
authentication_classes = (IPRestrictedTokenAuthentication,)
968969
permission_classes = (CheckAppPermission,)
969970

970971
def post(self, request, poolid, format=None):
@@ -982,7 +983,7 @@ def post(self, request, poolid, format=None):
982983

983984

984985
class PoolEnableView(APIView):
985-
authentication_classes = (TokenAuthentication,)
986+
authentication_classes = (IPRestrictedTokenAuthentication,)
986987
permission_classes = (CheckAppPermission,)
987988

988989
def post(self, request, poolid, format=None):
@@ -1002,7 +1003,7 @@ def post(self, request, poolid, format=None):
10021003

10031004

10041005
class PoolDisableView(APIView):
1005-
authentication_classes = (TokenAuthentication,)
1006+
authentication_classes = (IPRestrictedTokenAuthentication,)
10061007
permission_classes = (CheckAppPermission,)
10071008

10081009
def post(self, request, poolid, format=None):

server/server/utils.py

+32
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
import redis
66

7+
from rest_framework.authentication import TokenAuthentication
8+
from rest_framework.exceptions import PermissionDenied
9+
from django.conf import settings
10+
711
LOG = logging.getLogger("fuzzmanager.utils")
812

913

@@ -64,3 +68,31 @@ def release(self):
6468
"Failed to release lock: %s(%s) != %s", self.name, self.unique_id, existing
6569
)
6670
return False
71+
72+
def get_client_ip(request):
73+
"""
74+
Extracts the client IP address from request headers.
75+
"""
76+
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
77+
if x_forwarded_for:
78+
ip = x_forwarded_for.split(",")[0].strip()
79+
else:
80+
ip = request.META.get("REMOTE_ADDR")
81+
82+
return ip
83+
84+
class IPRestrictedTokenAuthentication(TokenAuthentication):
85+
def authenticate(self, request):
86+
if self.is_ip_restricted(request):
87+
raise PermissionDenied("IP address restricted. Access denied.")
88+
89+
return super().authenticate(request)
90+
91+
def is_ip_restricted(self, request):
92+
allowed_ips = set(getattr(settings, "ALLOWED_IPS", []))
93+
client_ip = get_client_ip(request)
94+
if client_ip not in allowed_ips:
95+
LOG.warning(f"IP address restricted: {client_ip}")
96+
return True
97+
98+
return False

server/taskmanager/views.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from django.shortcuts import get_object_or_404, redirect, render
66
from django.utils import timezone
77
from rest_framework import mixins, status, viewsets
8-
from rest_framework.authentication import SessionAuthentication, TokenAuthentication
8+
from rest_framework.authentication import SessionAuthentication
99
from rest_framework.decorators import action
1010
from rest_framework.filters import OrderingFilter
1111
from rest_framework.response import Response
1212

1313
from server.auth import CheckAppPermission
1414
from server.views import JsonQueryFilterBackend, SimpleQueryFilterBackend
15+
from server.utils import IPRestrictedTokenAuthentication
1516

1617
from .models import Pool, Task
1718
from .serializers import (
@@ -54,7 +55,7 @@ class PoolViewSet(viewsets.ReadOnlyModelViewSet):
5455
API endpoint that allows viewing Pools
5556
"""
5657

57-
authentication_classes = (TokenAuthentication, SessionAuthentication)
58+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
5859
permission_classes = (CheckAppPermission,)
5960
queryset = Pool.objects.all()
6061
serializer_class = PoolSerializer
@@ -82,7 +83,7 @@ class TaskViewSet(
8283
API endpoint that allows viewing Tasks
8384
"""
8485

85-
authentication_classes = (TokenAuthentication, SessionAuthentication)
86+
authentication_classes = (IPRestrictedTokenAuthentication, SessionAuthentication)
8687
permission_classes = (CheckAppPermission,)
8788
queryset = Task.objects.all()
8889
serializer_class = TaskSerializer
@@ -103,7 +104,7 @@ def get_serializer(self, *args, **kwds):
103104
return super().get_serializer(*args, **kwds)
104105

105106
@action(
106-
detail=False, methods=["post"], authentication_classes=(TokenAuthentication,)
107+
detail=False, methods=["post"], authentication_classes=(IPRestrictedTokenAuthentication,)
107108
)
108109
def update_status(self, request):
109110
if set(request.data.keys()) != {"client", "status_data"}:

0 commit comments

Comments
 (0)