Skip to content

Commit

Permalink
fix protected tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mayofaulkner committed Mar 21, 2024
1 parent 61193db commit 670ec9c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
13 changes: 8 additions & 5 deletions alyx/data/tests_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,22 +774,25 @@ def test_check_protected(self):
'name': 'drb1', # this is the repository name
}

d = self.client.post(reverse('register-file'), data)
d = self.ar(self.client.post(reverse('register-file'), data), 201)

# Check the same dataset to see if it is protected, should be unprotected
# and get a status 200 respons
_ = data.pop('name')
r = self.client.post(reverse('check-protected'), data)
self.assertEqual(r['status'], 200)

r = self.ar(self.client.get(reverse('check-protected'), data=data,
content_type='application/json'), 200)
self.assertEqual(r['status_code'], 200)

# add protected tag to the first dataset
dataset1 = Dataset.objects.get(pk=d[0]['id'])
tag1 = Tag.objects.get(name='tag1')
dataset1.tags.add(tag1)

# Check the same dataset to see if it is protected
r = self.client.post(reverse('check-protected'), data)
self.assertEqual(r['status'], 403)
r = self.ar(self.client.get(reverse('check-protected'), data=data,
content_type='application/json'), 200)
self.assertEqual(r['status_code'], 403)
self.assertEqual(r['error'], 'One or more datasets is protected')

def test_revisions(self):
Expand Down
8 changes: 5 additions & 3 deletions alyx/data/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,15 @@ def list(self, request):
- Status 200 is none of the datasets are protected
"""

user = request.data.get('created_by', None)
req = request.GET.dict() if len(request.data) == 0 else request.data

user = req.get('created_by', None)
if user:
user = get_user_model().objects.get(username=user)
else:
user = request.user

rel_dir_path = request.data.get('path', '')
rel_dir_path = req.get('path', '')
if not rel_dir_path:
raise ValueError("The path argument is required.")

Expand All @@ -368,7 +370,7 @@ def list(self, request):
rel_dir_path = rel_dir_path.replace('//', '/')
subject, date, session_number = _parse_path(rel_dir_path)

filenames = request.data.get('filenames', ())
filenames = req.get('filenames', ())
if isinstance(filenames, str):
filenames = filenames.split(',')

Expand Down

0 comments on commit 670ec9c

Please sign in to comment.