Skip to content

Commit

Permalink
chore: make UpdateBotInput.id as a required field and write a test fo…
Browse files Browse the repository at this point in the history
…r that
  • Loading branch information
ipeterov committed Jan 20, 2025
1 parent ecab5b2 commit 03e50ab
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 61 deletions.
77 changes: 31 additions & 46 deletions aiarena/core/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pytest_django.live_server_helper import LiveServer

from aiarena.core.models import WebsiteUser
from aiarena.core.utils import dict_camel_case, dict_get
from aiarena.core.utils import dict_get


class BrowserHelper:
Expand All @@ -34,43 +34,68 @@ class GraphQLTest:

# GraphQL mutation query used by `mutate` method.
mutation = None
# The mutation's name so that we can get the content from the JSON response
mutation_name = None
# stored django test client for the last performed query, makes it possible
# to check request/response attributes (like auth session)
last_query_client = None

def mutate(
self,
expected_status: int = 200,
variables: dict = None,
expected_status: int = HTTPStatus.OK.value,
mutation: str | None = None,
login_user: WebsiteUser | None = None,
expected_validation_errors=None,
**kwargs,
) -> dict | None:
"""
Perform GraphQL mutation.
"""
kwargs = dict_camel_case(kwargs)
return self.query(
expected_validation_errors = expected_validation_errors or {}
response_data = self.query(
query=(mutation or self.mutation),
expected_status=expected_status,
login_user=login_user,
variables=kwargs,
variables=variables,
**kwargs,
)

mutation_data = response_data[self.mutation_name] or {}
actual_errors = {error["field"]: error["messages"] for error in mutation_data.get("errors", [])}
assert (
actual_errors == expected_validation_errors
), f"Unexpected validation errors: {actual_errors}, expected {expected_validation_errors}"

return response_data

def query(
self,
query: str,
expected_status: int = HTTPStatus.OK.value,
expected_errors: list = None,
variables: dict | None = None,
login_user: WebsiteUser | None = None,
) -> dict | None:
"""Perform GraphQL query."""
expected_errors = expected_errors or []

self.last_query_client = self.client(login_user)

response = self.do_post(self.last_query_client, query, variables)

assert response.status_code == expected_status, (
f"Unexpected response status code: {response.status_code}\n" f"Response content: {response.content}"
)

content = json.loads(response.content)
error_messages = [error["message"] for error in content.get("errors", [])]
assert set(error_messages) == set(
expected_errors
), f"Unexpected errors: {error_messages}\nResponse content: {content}"

if response.status_code == HTTPStatus.OK.value:
return json.loads(response.content)
return json.loads(response.content)["data"]

@classmethod
def client(
Expand All @@ -86,46 +111,6 @@ def client(

return client

@classmethod
def assert_graphql_error(cls, response: dict, message: str):
"""
Check GQL response contains given error message.
"""
assert "errors" in response
messages = {error["message"] for error in response["errors"]}
if message not in messages:
msg = f'Expected to find "{message}" error, but found: {messages}'
raise AssertionError(msg)

@classmethod
def assert_graphql_error_like(cls, response: dict, substring: str):
"""
Check GQL response contains given error message.
"""
assert "errors" in response
messages = {error["message"] for error in response["errors"]}
for message in messages:
if substring in message:
return
msg = '"{}" wasn\'t found in error messages:\n{}'.format(
substring,
"\n".join(f" - {m}" for m in messages),
)
raise AssertionError(msg)

@classmethod
def assert_no_graphql_errors(cls, response: dict):
"""
Check GQL response does not contain any errors.
"""
err = {error["message"] for error in response.get("errors", [])}
msg = f"Expected no errors, but got: {err}"
assert not err, msg

@classmethod
def assert_access_denied(cls, response: dict):
cls.assert_graphql_error(response, "Access denied.")

@classmethod
def do_post(
cls,
Expand Down
3 changes: 3 additions & 0 deletions aiarena/graphql/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class UpdateBotInput(CleanedInputType):
bot_data_enabled = graphene.Boolean()
bot_data_publicly_downloadable = graphene.Boolean()

class Meta:
required_fields = ["id"]


class UpdateBot(CleanedInputMutation):
bot = graphene.Field(BotType)
Expand Down
59 changes: 44 additions & 15 deletions aiarena/graphql/tests/test_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@


class TestUpdateBot(GraphQLTest):
mutation_name = "updateBot"
mutation = """
mutation($input: UpdateBotInput!) {
updateBot(input: $input) {
bot {
id
}
errors {
field
messages
}
}
}
"""
Expand All @@ -27,11 +32,13 @@ def test_update_bot_success(self, user, bot):
self.mutate(
login_user=user,
expected_status=200,
input={
"id": self.to_global_id(BotType, bot.id),
"botZipPubliclyDownloadable": True,
"botDataEnabled": True,
"botDataPubliclyDownloadable": True,
variables={
"input": {
"id": self.to_global_id(BotType, bot.id),
"botZipPubliclyDownloadable": True,
"botDataEnabled": True,
"botDataPubliclyDownloadable": True,
}
},
)

Expand All @@ -49,14 +56,16 @@ def test_update_bot_unauthorized(self, user, other_user, bot):
assert bot.user == user
assert bot.bot_zip_publicly_downloadable is False

response = self.mutate(
self.mutate(
login_user=other_user,
input={
"id": self.to_global_id(BotType, bot.id),
"botZipPubliclyDownloadable": True,
variables={
"input": {
"id": self.to_global_id(BotType, bot.id),
"botZipPubliclyDownloadable": True,
}
},
expected_errors=["This is not your bot"],
)
self.assert_graphql_error_like(response, "This is not your bot")

# Verify bot was not updated
bot.refresh_from_db()
Expand All @@ -69,13 +78,33 @@ def test_update_bot_unauthenticated(self, bot):
# We expect the bot fixture to be created with those values
assert bot.bot_zip_publicly_downloadable is False

response = self.mutate(
input={
"id": self.to_global_id(BotType, bot.id),
"botZipPubliclyDownloadable": True,
self.mutate(
variables={
"input": {
"id": self.to_global_id(BotType, bot.id),
"botZipPubliclyDownloadable": True,
}
},
expected_errors=["You are not signed in"],
)

# Optionally, verify bot was not updated
bot.refresh_from_db()
assert bot.bot_zip_publicly_downloadable is False

def test_required_field_not_specified(self, user, bot):
"""
Test updating a bot without being authenticated.
"""
self.mutate(
login_user=user,
variables={
"input": {
"botZipPubliclyDownloadable": True,
}
},
expected_validation_errors={"id": ["Required field"]},
)
self.assert_graphql_error_like(response, "You are not signed in")

# Optionally, verify bot was not updated
bot.refresh_from_db()
Expand Down

0 comments on commit 03e50ab

Please sign in to comment.