From 1c94d261aec569a68a784c480556fdb21fc85bf4 Mon Sep 17 00:00:00 2001 From: apiyo Date: Mon, 24 Jul 2023 04:41:38 +0300 Subject: [PATCH] Use host from request in place of Site url --- onadata/apps/api/tools.py | 8 ++++--- onadata/apps/api/viewsets/dataview_viewset.py | 1 + .../apps/api/viewsets/v2/tableau_viewset.py | 2 +- onadata/apps/api/viewsets/xform_viewset.py | 1 + onadata/apps/logger/views.py | 4 ++-- onadata/apps/main/context_processors.py | 2 +- onadata/apps/viewer/views.py | 4 +++- .../serializers/organization_serializer.py | 17 ++++++++------ .../serializers/user_profile_serializer.py | 2 +- onadata/libs/tests/utils/test_email.py | 1 - onadata/libs/utils/api_export_tools.py | 3 +++ onadata/libs/utils/common_tools.py | 11 ++++------ onadata/libs/utils/csv_builder.py | 22 +++++++++++++------ onadata/libs/utils/export_builder.py | 18 +++++++++++++-- onadata/libs/utils/user_auth.py | 7 +++--- 15 files changed, 67 insertions(+), 36 deletions(-) diff --git a/onadata/apps/api/tools.py b/onadata/apps/api/tools.py index c147a80f13..017de9312f 100644 --- a/onadata/apps/api/tools.py +++ b/onadata/apps/api/tools.py @@ -168,7 +168,9 @@ def create_organization_object(org_name, creator, attrs=None): except IntegrityError as e: raise ValidationError(_(f"{org_name} already exists")) from e if email: - site = Site.objects.get(pk=settings.SITE_ID) + site = ("host" in attrs and attrs["host"]) or Site.objects.get( + pk=settings.SITE_ID + ) registration_profile.send_activation_email(site) profile = OrganizationProfile( user=new_user, @@ -749,9 +751,9 @@ def update_role_by_meta_xform_perms(xform): role.add(user, xform) -def replace_attachment_name_with_url(data): +def replace_attachment_name_with_url(data, request): """Replaces the attachment filename with a URL in ``data`` object.""" - site_url = Site.objects.get_current().domain + site_url = request.get_host() or Site.objects.get_current().domain for record in data: attachments: dict = record.json.get("_attachments") diff --git a/onadata/apps/api/viewsets/dataview_viewset.py b/onadata/apps/api/viewsets/dataview_viewset.py index f252266c0a..1ee101c2e3 100644 --- a/onadata/apps/api/viewsets/dataview_viewset.py +++ b/onadata/apps/api/viewsets/dataview_viewset.py @@ -221,6 +221,7 @@ def export_async(self, request, *args, **kwargs): dataview = self.get_object() xform = dataview.xform options = parse_request_export_options(params) + options["host"] = request.get_host() options.update( { diff --git a/onadata/apps/api/viewsets/v2/tableau_viewset.py b/onadata/apps/api/viewsets/v2/tableau_viewset.py index 74aeff03ba..5e99970d2d 100644 --- a/onadata/apps/api/viewsets/v2/tableau_viewset.py +++ b/onadata/apps/api/viewsets/v2/tableau_viewset.py @@ -224,7 +224,7 @@ def data(self, request, **kwargs): instances = self.paginate_queryset(instances) # Switch out media file names for url links in queryset - data = replace_attachment_name_with_url(instances) + data = replace_attachment_name_with_url(instances, request) data = process_tableau_data( TableauDataSerializer(data, many=True).data, xform ) diff --git a/onadata/apps/api/viewsets/xform_viewset.py b/onadata/apps/api/viewsets/xform_viewset.py index b7e9bc1a3a..46ac045c9e 100644 --- a/onadata/apps/api/viewsets/xform_viewset.py +++ b/onadata/apps/api/viewsets/xform_viewset.py @@ -872,6 +872,7 @@ def export_async(self, request, *args, **kwargs): meta = request.query_params.get("meta") data_id = request.query_params.get("data_id") options = parse_request_export_options(request.query_params) + options["host"] = request.get_host() options.update( { diff --git a/onadata/apps/logger/views.py b/onadata/apps/logger/views.py index 5a3e6aa844..0d842d25d4 100644 --- a/onadata/apps/logger/views.py +++ b/onadata/apps/logger/views.py @@ -91,7 +91,7 @@ def _html_submission_response(request, instance): data = {} data["username"] = instance.xform.user.username data["id_string"] = instance.xform.id_string - data["domain"] = Site.objects.get(id=settings.SITE_ID).domain + data["domain"] = request.get_host() or Site.objects.get(id=settings.SITE_ID).domain return render(request, "submission.html", data) @@ -217,7 +217,7 @@ def formList(request, username): # noqa N802 # unauthorized if user in auth request does not match user in path # unauthorized if user not active if not request.user.is_active: - return HttpResponseNotAuthorized() + return HttpResponseNotAuthorized(host=request.get_host()) # filter private forms (where require_auth=False) # for users who are non-owner diff --git a/onadata/apps/main/context_processors.py b/onadata/apps/main/context_processors.py index 96e115c8c2..cd450cb220 100644 --- a/onadata/apps/main/context_processors.py +++ b/onadata/apps/main/context_processors.py @@ -22,7 +22,7 @@ def site_name(request): """Returns the SITE_NAME/""" site_id = getattr(settings, "SITE_ID", None) try: - site = Site.objects.get(pk=site_id) + site = request.get_host() or Site.objects.get(pk=site_id) except Site.DoesNotExist: name = "example.org" else: diff --git a/onadata/apps/viewer/views.py b/onadata/apps/viewer/views.py index 53e8804394..2cd0b9e464 100644 --- a/onadata/apps/viewer/views.py +++ b/onadata/apps/viewer/views.py @@ -313,7 +313,7 @@ def data_export(request, username, id_string, export_type): # noqa C901 audit = {"xform": xform.id_string, "export_type": export_type} - options = {"extension": extension, "username": username, "id_string": id_string} + options = {"extension": extension, "username": username, "id_string": id_string, "host": request.get_host()} if query: options["query"] = query @@ -435,6 +435,7 @@ def create_export(request, username, id_string, export_type): "remove_group_name": str_to_bool(remove_group_name), "meta": meta.replace(",", "") if meta else None, "google_credentials": credential, + "host": request.get_host(), } try: @@ -510,6 +511,7 @@ def export_list(request, username, id_string, export_type): # noqa C901 "meta": export_meta, "token": export_token, "google_credentials": credential, + "host": request.get_host(), } if should_create_new_export(xform, export_type, options): diff --git a/onadata/libs/serializers/organization_serializer.py b/onadata/libs/serializers/organization_serializer.py index d151728083..0887c830a5 100644 --- a/onadata/libs/serializers/organization_serializer.py +++ b/onadata/libs/serializers/organization_serializer.py @@ -85,6 +85,7 @@ def create(self, validated_data): if "request" in self.context: creator = self.context["request"].user + validated_data["host"] = self.context["request"].get_host() validated_data["organization"] = org_name @@ -130,13 +131,15 @@ def _create_user_list(user_list): except UserProfile.DoesNotExist: profile = UserProfile.objects.create(user=u) - users_list.append({ - "user": u.username, - "role": get_role_in_org(u, obj), - "first_name": u.first_name, - "last_name": u.last_name, - "gravatar": profile.gravatar, - }) + users_list.append( + { + "user": u.username, + "role": get_role_in_org(u, obj), + "first_name": u.first_name, + "last_name": u.last_name, + "gravatar": profile.gravatar, + } + ) return users_list members = get_organization_members(obj) if obj else [] diff --git a/onadata/libs/serializers/user_profile_serializer.py b/onadata/libs/serializers/user_profile_serializer.py index 36ea893b9b..50001cc61a 100644 --- a/onadata/libs/serializers/user_profile_serializer.py +++ b/onadata/libs/serializers/user_profile_serializer.py @@ -272,7 +272,7 @@ def create(self, validated_data): metadata = {} username = params.get("username") password = params.get("password1", "") - site = Site.objects.get(pk=settings.SITE_ID) + site = request.get_host() or Site.objects.get(pk=settings.SITE_ID) new_user = None try: diff --git a/onadata/libs/tests/utils/test_email.py b/onadata/libs/tests/utils/test_email.py index 36c8677738..12cc55c9e5 100644 --- a/onadata/libs/tests/utils/test_email.py +++ b/onadata/libs/tests/utils/test_email.py @@ -225,4 +225,3 @@ def test_url_not_configured(self): """settings.PROJECT_INVITATION_URL not set""" url = get_project_invitation_url(self.custom_request) self.assertEqual(url, "http://testserver/api/v1/profiles") -# Add test case for using a different host diff --git a/onadata/libs/utils/api_export_tools.py b/onadata/libs/utils/api_export_tools.py index 72df74a44d..ee96000e35 100644 --- a/onadata/libs/utils/api_export_tools.py +++ b/onadata/libs/utils/api_export_tools.py @@ -152,6 +152,8 @@ def custom_response_handler( # noqa: C0901 dataview_pk = hasattr(dataview, "pk") and dataview.pk options["dataview_pk"] = dataview_pk + options["host"] = request.get_host() + if dataview: columns_with_hxl = get_columns_with_hxl(xform.survey.get("children")) @@ -249,6 +251,7 @@ def _generate_new_export( # noqa: C0901 "extension": extension, "username": xform.user.username, "id_string": xform.id_string, + "host": request.get_host(), } if query: options["query"] = query diff --git a/onadata/libs/utils/common_tools.py b/onadata/libs/utils/common_tools.py index e5a9e051a1..8eba194e95 100644 --- a/onadata/libs/utils/common_tools.py +++ b/onadata/libs/utils/common_tools.py @@ -237,19 +237,15 @@ def __ne__(self, other): return ComparatorClass -def current_site_url(path): +def current_site_url(path, host): """ Returns fully qualified URL (no trailing slash) for the current site. :param path :return: complete url """ - # pylint: disable=import-outside-toplevel - from django.contrib.sites.models import Site - - current_site = Site.objects.get_current() protocol = getattr(settings, "ONA_SITE_PROTOCOL", "http") port = getattr(settings, "ONA_SITE_PORT", "") - url = f"{protocol}://{current_site.domain}" + url = f"{protocol}://{host}" if port: url += f":{port}" if path: @@ -315,6 +311,7 @@ def get_value_or_attachment_uri( attachment_list=None, show_choice_labels=False, language=None, + host=None, ): """ Gets either the attachment value or the attachment url @@ -339,7 +336,7 @@ def get_value_or_attachment_uri( if a.get("name") == value ] if attachments: - value = current_site_url(attachments[0].get("download_url", "")) + value = current_site_url(attachments[0].get("download_url", ""), host) return value diff --git a/onadata/libs/utils/csv_builder.py b/onadata/libs/utils/csv_builder.py index 0f5182561f..be044a6be9 100644 --- a/onadata/libs/utils/csv_builder.py +++ b/onadata/libs/utils/csv_builder.py @@ -8,6 +8,7 @@ from django.conf import settings from django.db.models.query import QuerySet from django.utils.translation import gettext as _ +from django.contrib.sites.models import Site import unicodecsv as csv from pyxform.question import Question @@ -253,8 +254,8 @@ def __init__( show_choice_labels=True, include_reviews=False, language=None, + host=Site.objects.get(pk=settings.SITE_ID), ): - self.username = username self.id_string = id_string self.filter_query = filter_query @@ -302,6 +303,7 @@ def __init__( self.index_tags = index_tags self.show_choice_labels = show_choice_labels self.language = language + self.host = host self._setup() @@ -464,7 +466,7 @@ def _tag_edit_string(cls, record): @classmethod def _split_gps_fields(cls, record, gps_fields): updated_gps_fields = {} - for (key, value) in iteritems(record): + for key, value in iteritems(record): if key in gps_fields and isinstance(value, str): gps_xpaths = DataDictionary.get_additional_geopoint_xpaths(key) gps_parts = {xpath: None for xpath in gps_xpaths} @@ -552,8 +554,8 @@ def __init__( show_choice_labels=False, include_reviews=False, language=None, + host=Site.objects.get(pk=settings.SITE_ID), ): - super().__init__( username, id_string, @@ -600,6 +602,7 @@ def _reindex( index_tags=DEFAULT_INDEX_TAGS, show_choice_labels=False, language=None, + host=Site.objects.get(pk=settings.SITE_ID), ): """ Flatten list columns by appending an index, otherwise return as is @@ -643,7 +646,7 @@ def get_ordered_repeat_value(xpath, repeat_value): # set within a group. _item = item - for (nested_key, nested_val) in iteritems(_item): + for nested_key, nested_val in iteritems(_item): # given the key "children/details" and nested_key/ # abbreviated xpath # "children/details/immunization/polio_1", @@ -677,6 +680,7 @@ def get_ordered_repeat_value(xpath, repeat_value): index_tags=index_tags, show_choice_labels=show_choice_labels, language=language, + host=host, ) ) else: @@ -698,6 +702,7 @@ def get_ordered_repeat_value(xpath, repeat_value): include_images, show_choice_labels=show_choice_labels, language=language, + host=host, ) else: record[key] = get_value_or_attachment_uri( @@ -708,6 +713,7 @@ def get_ordered_repeat_value(xpath, repeat_value): include_images, show_choice_labels=show_choice_labels, language=language, + host=host, ) else: # anything that's not a list will be in the top level dict so its @@ -724,6 +730,7 @@ def get_ordered_repeat_value(xpath, repeat_value): include_images, show_choice_labels=show_choice_labels, language=language, + host=host, ) return record @@ -763,7 +770,7 @@ def _update_ordered_columns_from_data(self, cursor): """ # add ordered columns for select multiples if self.split_select_multiples: - for (key, choices) in iteritems(self.select_multiples): + for key, choices in iteritems(self.select_multiples): # HACK to ensure choices are NOT duplicated if key in self.ordered_columns.keys(): self.ordered_columns[key] = remove_dups_from_list_maintain_order( @@ -783,7 +790,7 @@ def _update_ordered_columns_from_data(self, cursor): # add ordered columns for nested repeat data for record in cursor: # re index column repeats - for (key, value) in iteritems(record): + for key, value in iteritems(record): self._reindex( key, value, @@ -795,6 +802,7 @@ def _update_ordered_columns_from_data(self, cursor): index_tags=self.index_tags, show_choice_labels=self.show_choice_labels, language=self.language, + host=self.host, ) def _format_for_dataframe(self, cursor): @@ -818,7 +826,7 @@ def _format_for_dataframe(self, cursor): self._tag_edit_string(record) flat_dict = {} # re index repeats - for (key, value) in iteritems(record): + for key, value in iteritems(record): reindexed = self._reindex( key, value, diff --git a/onadata/libs/utils/export_builder.py b/onadata/libs/utils/export_builder.py index 705a079c81..dd53ab2639 100644 --- a/onadata/libs/utils/export_builder.py +++ b/onadata/libs/utils/export_builder.py @@ -106,7 +106,7 @@ def encode_if_str(row, key, encode_dates=False, sav_writer=None): # pylint: disable=too-many-arguments,too-many-locals,too-many-branches -def dict_to_joined_export(data, index, indices, name, survey, row, media_xpaths=None): +def dict_to_joined_export(data, index, indices, name, survey, row, host, media_xpaths=None): """ Converts a dict into one or more tabular datasets :param data: current record which can be changed or updated @@ -129,7 +129,7 @@ def dict_to_joined_export(data, index, indices, name, survey, row, media_xpaths= indices[key] += 1 child_index = indices[key] new_output = dict_to_joined_export( - child, child_index, indices, key, survey, row, media_xpaths + child, child_index, indices, key, survey, row, host, media_xpaths ) item = { INDEX: child_index, @@ -163,6 +163,7 @@ def dict_to_joined_export(data, index, indices, name, survey, row, media_xpaths= data_dictionary, media_xpaths, row and row.get(ATTACHMENTS), + host=host, ) return output @@ -922,6 +923,8 @@ def write_row(row, csv_writer, fields): index = 1 indices = {} survey_name = self.survey.name + options=kwargs.get("options") + host = options.get("host") for i, row_data in enumerate(data, start=1): # decode mongo section names joined_export = dict_to_joined_export( @@ -931,6 +934,7 @@ def write_row(row, csv_writer, fields): survey_name, self.survey, row_data, + host, media_xpaths, ) output = decode_mongo_encoded_section_names(joined_export) @@ -1060,6 +1064,9 @@ def write_row(data, work_sheet, fields, work_sheet_titles): index = 1 indices = {} survey_name = self.survey.name + + options = kwargs.get("options") + host = options.get("host") for i, row_data in enumerate(data, start=1): joined_export = dict_to_joined_export( row_data, @@ -1068,6 +1075,7 @@ def write_row(data, work_sheet, fields, work_sheet_titles): survey_name, self.survey, row_data, + host, media_xpaths, ) output = decode_mongo_encoded_section_names(joined_export) @@ -1122,6 +1130,7 @@ def to_flat_csv_export( xform = kwargs.get("xform") options = kwargs.get("options") total_records = kwargs.get("total_records") + host = options.get("host") win_excel_utf8 = options.get("win_excel_utf8") if options else False index_tags = options.get(REPEAT_INDEX_TAGS, self.REPEAT_INDEX_TAGS) show_choice_labels = options.get("show_choice_labels", False) @@ -1149,6 +1158,7 @@ def to_flat_csv_export( show_choice_labels=show_choice_labels, include_reviews=self.INCLUDE_REVIEWS, language=language, + host=host ) csv_builder.export_to(path, dataview=dataview) @@ -1394,6 +1404,9 @@ def write_row(row, sav_writer, fields): index = 1 indices = {} survey_name = self.survey.name + + options = kwargs.get("options") + host = options.get("host") for i, row_data in enumerate(data, start=1): # decode mongo section names joined_export = dict_to_joined_export( @@ -1403,6 +1416,7 @@ def write_row(row, sav_writer, fields): survey_name, self.survey, row_data, + host, media_xpaths, ) output = decode_mongo_encoded_section_names(joined_export) diff --git a/onadata/libs/utils/user_auth.py b/onadata/libs/utils/user_auth.py index 5e3d9c75cc..85c585799f 100644 --- a/onadata/libs/utils/user_auth.py +++ b/onadata/libs/utils/user_auth.py @@ -34,9 +34,10 @@ class HttpResponseNotAuthorized(HttpResponse): status_code = 401 - def __init__(self): + def __init__(self, host=Site.objects.get_current().name, *args, **kwargs): HttpResponse.__init__(self) - self["WWW-Authenticate"] = f'Basic realm="{Site.objects.get_current().name}"' + self["WWW-Authenticate"] = f'Basic realm="{host}"' + super().__init__(*args, **kwargs) def check_and_set_user(request, username): @@ -180,7 +181,7 @@ def helper_auth_helper(request): request.user = user return None - return HttpResponseNotAuthorized() + return HttpResponseNotAuthorized(host=request.get_host()) def basic_http_auth(func):