diff --git a/evap/evaluation/tools.py b/evap/evaluation/tools.py index 68bb8f702b..07020c75cc 100644 --- a/evap/evaluation/tools.py +++ b/evap/evaluation/tools.py @@ -2,14 +2,15 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterable, Mapping -from typing import Any, TypeVar +from typing import Any, Protocol, TypeVar from urllib.parse import quote import xlwt from django.conf import settings from django.core.exceptions import SuspiciousOperation, ValidationError from django.db.models import Model -from django.http import HttpResponse +from django.forms.formsets import BaseFormSet +from django.http import HttpRequest, HttpResponse from django.shortcuts import get_object_or_404 from django.utils.translation import get_language from django.views.generic import FormView @@ -42,7 +43,7 @@ def get_object_from_dict_pk_entry_or_logged_40x(model_cls: type[M], dict_obj: Ma raise SuspiciousOperation from e -def is_prefetched(instance, attribute_name: str): +def is_prefetched(instance, attribute_name: str) -> bool: """ Is the given related attribute prefetched? Can be used to do ordering or counting in python and avoid additional database queries @@ -58,7 +59,7 @@ def is_prefetched(instance, attribute_name: str): return False -def discard_cached_related_objects(instance): +def discard_cached_related_objects(instance: M) -> M: """ Discard all cached related objects (for ForeignKey and M2M Fields). Useful if there were changes, but django's caching would still give us the old @@ -66,44 +67,44 @@ def discard_cached_related_objects(instance): hierarchy (e.g. for storing instances in a cache) """ # Extracted from django's refresh_from_db, which sadly doesn't offer this part alone (without hitting the DB). - for field in instance._meta.concrete_fields: + for field in instance._meta.concrete_fields: # type: ignore if field.is_relation and field.is_cached(instance): field.delete_cached_value(instance) - for field in instance._meta.related_objects: + for field in instance._meta.related_objects: # type: ignore if field.is_cached(instance): field.delete_cached_value(instance) - instance._prefetched_objects_cache = {} + instance._prefetched_objects_cache = {} # type: ignore return instance -def is_external_email(email): +def is_external_email(email: str) -> bool: return not any(email.endswith("@" + domain) for domain in settings.INSTITUTION_EMAIL_DOMAINS) -def sort_formset(request, formset): +def sort_formset(request: HttpRequest, formset: BaseFormSet) -> None: if request.POST: # if not, there will be no cleaned_data and the models should already be sorted anyways formset.is_valid() # make sure all forms have cleaned_data formset.forms.sort(key=lambda f: f.cleaned_data.get("order", 9001)) -def date_to_datetime(date): +def date_to_datetime(date: datetime.date) -> datetime.datetime: return datetime.datetime(year=date.year, month=date.month, day=date.day) -def vote_end_datetime(vote_end_date): +def vote_end_datetime(vote_end_date: datetime.date) -> datetime.datetime: # The evaluation actually ends at EVALUATION_END_OFFSET_HOURS:00 of the day AFTER self.vote_end_date. return date_to_datetime(vote_end_date) + datetime.timedelta(hours=24 + settings.EVALUATION_END_OFFSET_HOURS) -def get_parameter_from_url_or_session(request, parameter, default=False): - result = request.GET.get(parameter, None) - if result is None: # if no parameter is given take session value +def get_parameter_from_url_or_session(request: HttpRequest, parameter: str, default=False) -> bool: + result_str = request.GET.get(parameter, None) + if result_str is None: # if no parameter is given take session value result = request.session.get(parameter, default) else: - result = {"true": True, "false": False}.get(result.lower()) # convert parameter to boolean + result = {"true": True, "false": False}.get(result_str.lower()) # convert parameter to boolean request.session[parameter] = result # store value for session return result @@ -115,7 +116,10 @@ def translate(**kwargs): return property(lambda self: getattr(self, kwargs[get_language() or "en"])) -def clean_email(email): +EmailT = TypeVar("EmailT", str, None) + + +def clean_email(email: EmailT) -> EmailT: if email: email = email.strip().lower() # Replace email domains in case there are multiple alias domains used in the organisation and all emails should @@ -126,11 +130,11 @@ def clean_email(email): return email -def capitalize_first(string): +def capitalize_first(string: str) -> str: return string[0].upper() + string[1:] -def ilen(iterable): +def ilen(iterable: Iterable) -> int: return sum(1 for _ in iterable) @@ -148,7 +152,7 @@ class FormsetView(FormView): def form_class(self): return self.formset_class - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs) -> dict[str, Any]: context = super().get_context_data(**kwargs) context["formset"] = context.pop("form") return context @@ -157,19 +161,24 @@ def get_context_data(self, **kwargs): # `get_formset_kwargs`. Users can thus override `get_formset_kwargs` instead. If it is not overridden, we delegate # to the original `get_form_kwargs` instead. The same approach is used for the other renamed methods. - def get_form_kwargs(self): + def get_form_kwargs(self) -> dict: return self.get_formset_kwargs() - def get_formset_kwargs(self): + def get_formset_kwargs(self) -> dict: return super().get_form_kwargs() - def form_valid(self, form): + def form_valid(self, form) -> HttpResponse: return self.formset_valid(form) - def formset_valid(self, formset): + def formset_valid(self, formset) -> HttpResponse: return super().form_valid(formset) +class HasFormValid(Protocol): + def form_valid(self, form): + pass + + class SaveValidFormMixin: """ Call `form.save()` if the submitted form is valid. @@ -178,7 +187,7 @@ class SaveValidFormMixin: example if a formset for a collection of objects is submitted. """ - def form_valid(self, form): + def form_valid(self: HasFormValid, form) -> HttpResponse: form.save() return super().form_valid(form) @@ -193,11 +202,11 @@ class AttachmentResponse(HttpResponse): _to the response instance_ as if it was a writable file. """ - def __init__(self, filename, content_type=None, **kwargs): + def __init__(self, filename: str, content_type=None, **kwargs) -> None: super().__init__(content_type=content_type, **kwargs) self.set_content_disposition(filename) - def set_content_disposition(self, filename): + def set_content_disposition(self, filename: str) -> None: try: filename.encode("ascii") self["Content-Disposition"] = f'attachment; filename="{filename}"' @@ -215,7 +224,7 @@ class HttpResponseNoContent(HttpResponse): status_code = 204 - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) del self["content-type"] @@ -244,7 +253,7 @@ class ExcelExporter(ABC): # have a sheet added at initialization. default_sheet_name: str | None = None - def __init__(self): + def __init__(self) -> None: self.workbook = xlwt.Workbook() self.cur_row = 0 self.cur_col = 0 @@ -253,7 +262,7 @@ def __init__(self): else: self.cur_sheet = None - def write_cell(self, label="", style="default"): + def write_cell(self, label: str | None = "", style: str = "default") -> None: """Write a single cell and move to the next column.""" self.cur_sheet.write( self.cur_row, @@ -263,11 +272,11 @@ def write_cell(self, label="", style="default"): ) self.cur_col += 1 - def next_row(self): + def next_row(self) -> None: self.cur_col = 0 self.cur_row += 1 - def write_row(self, vals, style="default"): + def write_row(self, vals: Iterable[str], style: str = "default") -> None: """ Write a cell for every value and go to the next row. Styling can be chosen @@ -278,16 +287,16 @@ def write_row(self, vals, style="default"): self.write_cell(val, style=style(val) if callable(style) else style) self.next_row() - def write_empty_row_with_styles(self, styles): + def write_empty_row_with_styles(self, styles: Iterable[str]) -> None: for style in styles: self.write_cell(None, style) self.next_row() @abstractmethod - def export_impl(self, *args, **kwargs): + def export_impl(self, *args, **kwargs) -> None: """Specify the logic to insert the data into the sheet here.""" - def export(self, response, *args, **kwargs): + def export(self, response, *args, **kwargs) -> None: """Convenience method to avoid some boilerplate.""" self.export_impl(*args, **kwargs) self.workbook.save(response) diff --git a/evap/grades/views.py b/evap/grades/views.py index fa8b0a8119..88d779df9d 100644 --- a/evap/grades/views.py +++ b/evap/grades/views.py @@ -1,3 +1,5 @@ +from typing import Any + from django.conf import settings from django.contrib import messages from django.core.exceptions import PermissionDenied, SuspiciousOperation @@ -23,7 +25,7 @@ class IndexView(TemplateView): template_name = "grades_index.html" - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs) -> dict[str, Any]: return super().get_context_data(**kwargs) | { "semesters": Semester.objects.filter(grade_documents_are_deleted=False), "disable_breadcrumb_grades": True, @@ -51,19 +53,19 @@ class SemesterView(DetailView): object: Semester - def get_object(self, *args, **kwargs): + def get_object(self, *args, **kwargs) -> Semester: semester = super().get_object(*args, **kwargs) if semester.grade_documents_are_deleted: raise PermissionDenied return semester - def get_context_data(self, **kwargs): - courses = ( + def get_context_data(self, **kwargs) -> dict[str, Any]: + query = ( self.object.courses.filter(evaluations__wait_for_grade_upload_before_publishing=True) .exclude(evaluations__state=Evaluation.State.NEW) .distinct() ) - courses = course_grade_document_count_tuples(courses) + courses = course_grade_document_count_tuples(query) return super().get_context_data(**kwargs) | { "courses": courses, @@ -77,13 +79,13 @@ class CourseView(DetailView): model = Course pk_url_kwarg = "course_id" - def get_object(self, *args, **kwargs): + def get_object(self, *args, **kwargs) -> Course: course = super().get_object(*args, **kwargs) if course.semester.grade_documents_are_deleted: raise PermissionDenied return course - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs) -> dict[str, Any]: return super().get_context_data(**kwargs) | { "semester": self.object.semester, "grade_documents": self.object.grade_documents.all(), diff --git a/evap/staff/views.py b/evap/staff/views.py index 9e7e117f83..cf0d6346f8 100644 --- a/evap/staff/views.py +++ b/evap/staff/views.py @@ -14,7 +14,7 @@ from django.db import IntegrityError, transaction from django.db.models import BooleanField, Case, Count, ExpressionWrapper, IntegerField, Prefetch, Q, Sum, When from django.dispatch import receiver -from django.forms import formset_factory +from django.forms import BaseForm, formset_factory from django.forms.models import inlineformset_factory, modelformset_factory from django.http import Http404, HttpRequest, HttpResponse, HttpResponseBadRequest, HttpResponseRedirect from django.shortcuts import get_object_or_404, redirect, render @@ -580,7 +580,8 @@ class SemesterCreateView(SuccessMessageMixin, CreateView): form_class = SemesterForm success_message = gettext_lazy("Successfully created semester.") - def get_success_url(self): + def get_success_url(self) -> str: + assert self.object is not None return reverse("staff:semester_view", args=[self.object.id]) @@ -592,7 +593,7 @@ class SemesterEditView(SuccessMessageMixin, UpdateView): pk_url_kwarg = "semester_id" success_message = gettext_lazy("Successfully updated semester.") - def get_success_url(self): + def get_success_url(self) -> str: return reverse("staff:semester_view", args=[self.object.id]) @@ -1050,13 +1051,13 @@ class CourseEditView(SuccessMessageMixin, UpdateView): object: Course - def get_object(self, *args, **kwargs): + def get_object(self, *args, **kwargs) -> Course: course = super().get_object(*args, **kwargs) if self.request.method == "POST" and not course.can_be_edited_by_manager: raise SuspiciousOperation("Modifying this course is not allowed.") return course - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs) -> dict[str, Any]: context_data = super().get_context_data(**kwargs) | { "semester": self.object.semester, "editable": self.object.can_be_edited_by_manager, @@ -1065,7 +1066,9 @@ def get_context_data(self, **kwargs): context_data["course_form"] = context_data.pop("form") return context_data - def form_valid(self, form): + def form_valid(self, form: BaseForm) -> HttpResponse: + assert isinstance(form, CourseForm) # https://www.github.com/typeddjango/django-stubs/issues/1809 + if self.request.POST.get("operation") not in ("save", "save_create_evaluation", "save_create_single_result"): raise SuspiciousOperation("Invalid POST operation") @@ -1074,7 +1077,7 @@ def form_valid(self, form): update_template_cache_of_published_evaluations_in_course(self.object) return response - def get_success_url(self): + def get_success_url(self) -> str: match self.request.POST["operation"]: case "save": return reverse("staff:semester_view", args=[self.object.semester.id]) @@ -1082,6 +1085,7 @@ def get_success_url(self): return reverse("staff:evaluation_create_for_course", args=[self.object.id]) case "save_create_single_result": return reverse("staff:single_result_create_for_course", args=[self.object.id]) + raise SuspiciousOperation("Unexpected operation") @require_POST @@ -2290,7 +2294,7 @@ class UserMergeSelectionView(FormView): form_class = UserMergeSelectionForm template_name = "staff_user_merge_selection.html" - def form_valid(self, form): + def form_valid(self, form: UserMergeSelectionForm) -> HttpResponse: return redirect( "staff:user_merge", form.cleaned_data["main_user"].id, @@ -2334,7 +2338,7 @@ class TemplateEditView(SuccessMessageMixin, UpdateView): success_url = reverse_lazy("staff:index") template_name = "staff_template_form.html" - def get_context_data(self, **kwargs) -> dict: + def get_context_data(self, **kwargs) -> dict[str, Any]: context = super().get_context_data(**kwargs) template = context["template"] = context.pop("emailtemplate") @@ -2377,7 +2381,7 @@ class FaqIndexView(SuccessMessageMixin, SaveValidFormMixin, FormsetView): success_url = reverse_lazy("staff:faq_index") success_message = gettext_lazy("Successfully updated the FAQ sections.") - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs) -> dict[str, Any]: return super().get_context_data(**kwargs) | {"sections": FaqSection.objects.all()}