Skip to content

Commit

Permalink
Use host from request in place of Site url
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankApiyo committed Jul 24, 2023
1 parent 30dc934 commit 1c94d26
Show file tree
Hide file tree
Showing 15 changed files with 67 additions and 36 deletions.
8 changes: 5 additions & 3 deletions onadata/apps/api/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def create_organization_object(org_name, creator, attrs=None):
except IntegrityError as e:
raise ValidationError(_(f"{org_name} already exists")) from e
if email:
site = Site.objects.get(pk=settings.SITE_ID)
site = ("host" in attrs and attrs["host"]) or Site.objects.get(
pk=settings.SITE_ID
)
registration_profile.send_activation_email(site)
profile = OrganizationProfile(
user=new_user,
Expand Down Expand Up @@ -749,9 +751,9 @@ def update_role_by_meta_xform_perms(xform):
role.add(user, xform)


def replace_attachment_name_with_url(data):
def replace_attachment_name_with_url(data, request):
"""Replaces the attachment filename with a URL in ``data`` object."""
site_url = Site.objects.get_current().domain
site_url = request.get_host() or Site.objects.get_current().domain

for record in data:
attachments: dict = record.json.get("_attachments")
Expand Down
1 change: 1 addition & 0 deletions onadata/apps/api/viewsets/dataview_viewset.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def export_async(self, request, *args, **kwargs):
dataview = self.get_object()
xform = dataview.xform
options = parse_request_export_options(params)
options["host"] = request.get_host()

options.update(
{
Expand Down
2 changes: 1 addition & 1 deletion onadata/apps/api/viewsets/v2/tableau_viewset.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def data(self, request, **kwargs):
instances = self.paginate_queryset(instances)

# Switch out media file names for url links in queryset
data = replace_attachment_name_with_url(instances)
data = replace_attachment_name_with_url(instances, request)
data = process_tableau_data(
TableauDataSerializer(data, many=True).data, xform
)
Expand Down
1 change: 1 addition & 0 deletions onadata/apps/api/viewsets/xform_viewset.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ def export_async(self, request, *args, **kwargs):
meta = request.query_params.get("meta")
data_id = request.query_params.get("data_id")
options = parse_request_export_options(request.query_params)
options["host"] = request.get_host()

options.update(
{
Expand Down
4 changes: 2 additions & 2 deletions onadata/apps/logger/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _html_submission_response(request, instance):
data = {}
data["username"] = instance.xform.user.username
data["id_string"] = instance.xform.id_string
data["domain"] = Site.objects.get(id=settings.SITE_ID).domain
data["domain"] = request.get_host() or Site.objects.get(id=settings.SITE_ID).domain

return render(request, "submission.html", data)

Expand Down Expand Up @@ -217,7 +217,7 @@ def formList(request, username): # noqa N802
# unauthorized if user in auth request does not match user in path
# unauthorized if user not active
if not request.user.is_active:
return HttpResponseNotAuthorized()
return HttpResponseNotAuthorized(host=request.get_host())

# filter private forms (where require_auth=False)
# for users who are non-owner
Expand Down
2 changes: 1 addition & 1 deletion onadata/apps/main/context_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def site_name(request):
"""Returns the SITE_NAME/"""
site_id = getattr(settings, "SITE_ID", None)
try:
site = Site.objects.get(pk=site_id)
site = request.get_host() or Site.objects.get(pk=site_id)
except Site.DoesNotExist:
name = "example.org"
else:
Expand Down
4 changes: 3 additions & 1 deletion onadata/apps/viewer/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def data_export(request, username, id_string, export_type): # noqa C901

audit = {"xform": xform.id_string, "export_type": export_type}

options = {"extension": extension, "username": username, "id_string": id_string}
options = {"extension": extension, "username": username, "id_string": id_string, "host": request.get_host()}
if query:
options["query"] = query

Expand Down Expand Up @@ -435,6 +435,7 @@ def create_export(request, username, id_string, export_type):
"remove_group_name": str_to_bool(remove_group_name),
"meta": meta.replace(",", "") if meta else None,
"google_credentials": credential,
"host": request.get_host(),
}

try:
Expand Down Expand Up @@ -510,6 +511,7 @@ def export_list(request, username, id_string, export_type): # noqa C901
"meta": export_meta,
"token": export_token,
"google_credentials": credential,
"host": request.get_host(),
}

if should_create_new_export(xform, export_type, options):
Expand Down
17 changes: 10 additions & 7 deletions onadata/libs/serializers/organization_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def create(self, validated_data):

if "request" in self.context:
creator = self.context["request"].user
validated_data["host"] = self.context["request"].get_host()

validated_data["organization"] = org_name

Expand Down Expand Up @@ -130,13 +131,15 @@ def _create_user_list(user_list):
except UserProfile.DoesNotExist:
profile = UserProfile.objects.create(user=u)

users_list.append({
"user": u.username,
"role": get_role_in_org(u, obj),
"first_name": u.first_name,
"last_name": u.last_name,
"gravatar": profile.gravatar,
})
users_list.append(
{
"user": u.username,
"role": get_role_in_org(u, obj),
"first_name": u.first_name,
"last_name": u.last_name,
"gravatar": profile.gravatar,
}
)
return users_list

members = get_organization_members(obj) if obj else []
Expand Down
2 changes: 1 addition & 1 deletion onadata/libs/serializers/user_profile_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def create(self, validated_data):
metadata = {}
username = params.get("username")
password = params.get("password1", "")
site = Site.objects.get(pk=settings.SITE_ID)
site = request.get_host() or Site.objects.get(pk=settings.SITE_ID)
new_user = None

try:
Expand Down
1 change: 0 additions & 1 deletion onadata/libs/tests/utils/test_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,3 @@ def test_url_not_configured(self):
"""settings.PROJECT_INVITATION_URL not set"""
url = get_project_invitation_url(self.custom_request)
self.assertEqual(url, "http://testserver/api/v1/profiles")
# Add test case for using a different host
3 changes: 3 additions & 0 deletions onadata/libs/utils/api_export_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def custom_response_handler( # noqa: C0901
dataview_pk = hasattr(dataview, "pk") and dataview.pk
options["dataview_pk"] = dataview_pk

options["host"] = request.get_host()

if dataview:
columns_with_hxl = get_columns_with_hxl(xform.survey.get("children"))

Expand Down Expand Up @@ -249,6 +251,7 @@ def _generate_new_export( # noqa: C0901
"extension": extension,
"username": xform.user.username,
"id_string": xform.id_string,
"host": request.get_host(),
}
if query:
options["query"] = query
Expand Down
11 changes: 4 additions & 7 deletions onadata/libs/utils/common_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,15 @@ def __ne__(self, other):
return ComparatorClass


def current_site_url(path):
def current_site_url(path, host):
"""
Returns fully qualified URL (no trailing slash) for the current site.
:param path
:return: complete url
"""
# pylint: disable=import-outside-toplevel
from django.contrib.sites.models import Site

current_site = Site.objects.get_current()
protocol = getattr(settings, "ONA_SITE_PROTOCOL", "http")
port = getattr(settings, "ONA_SITE_PORT", "")
url = f"{protocol}://{current_site.domain}"
url = f"{protocol}://{host}"
if port:
url += f":{port}"
if path:
Expand Down Expand Up @@ -315,6 +311,7 @@ def get_value_or_attachment_uri(
attachment_list=None,
show_choice_labels=False,
language=None,
host=None,
):
"""
Gets either the attachment value or the attachment url
Expand All @@ -339,7 +336,7 @@ def get_value_or_attachment_uri(
if a.get("name") == value
]
if attachments:
value = current_site_url(attachments[0].get("download_url", ""))
value = current_site_url(attachments[0].get("download_url", ""), host)

return value

Expand Down
22 changes: 15 additions & 7 deletions onadata/libs/utils/csv_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.conf import settings
from django.db.models.query import QuerySet
from django.utils.translation import gettext as _
from django.contrib.sites.models import Site

import unicodecsv as csv
from pyxform.question import Question
Expand Down Expand Up @@ -253,8 +254,8 @@ def __init__(
show_choice_labels=True,
include_reviews=False,
language=None,
host=Site.objects.get(pk=settings.SITE_ID),
):

self.username = username
self.id_string = id_string
self.filter_query = filter_query
Expand Down Expand Up @@ -302,6 +303,7 @@ def __init__(
self.index_tags = index_tags
self.show_choice_labels = show_choice_labels
self.language = language
self.host = host

self._setup()

Expand Down Expand Up @@ -464,7 +466,7 @@ def _tag_edit_string(cls, record):
@classmethod
def _split_gps_fields(cls, record, gps_fields):
updated_gps_fields = {}
for (key, value) in iteritems(record):
for key, value in iteritems(record):
if key in gps_fields and isinstance(value, str):
gps_xpaths = DataDictionary.get_additional_geopoint_xpaths(key)
gps_parts = {xpath: None for xpath in gps_xpaths}
Expand Down Expand Up @@ -552,8 +554,8 @@ def __init__(
show_choice_labels=False,
include_reviews=False,
language=None,
host=Site.objects.get(pk=settings.SITE_ID),
):

super().__init__(
username,
id_string,
Expand Down Expand Up @@ -600,6 +602,7 @@ def _reindex(
index_tags=DEFAULT_INDEX_TAGS,
show_choice_labels=False,
language=None,
host=Site.objects.get(pk=settings.SITE_ID),
):
"""
Flatten list columns by appending an index, otherwise return as is
Expand Down Expand Up @@ -643,7 +646,7 @@ def get_ordered_repeat_value(xpath, repeat_value):
# set within a group.
_item = item

for (nested_key, nested_val) in iteritems(_item):
for nested_key, nested_val in iteritems(_item):
# given the key "children/details" and nested_key/
# abbreviated xpath
# "children/details/immunization/polio_1",
Expand Down Expand Up @@ -677,6 +680,7 @@ def get_ordered_repeat_value(xpath, repeat_value):
index_tags=index_tags,
show_choice_labels=show_choice_labels,
language=language,
host=host,
)
)
else:
Expand All @@ -698,6 +702,7 @@ def get_ordered_repeat_value(xpath, repeat_value):
include_images,
show_choice_labels=show_choice_labels,
language=language,
host=host,
)
else:
record[key] = get_value_or_attachment_uri(
Expand All @@ -708,6 +713,7 @@ def get_ordered_repeat_value(xpath, repeat_value):
include_images,
show_choice_labels=show_choice_labels,
language=language,
host=host,
)
else:
# anything that's not a list will be in the top level dict so its
Expand All @@ -724,6 +730,7 @@ def get_ordered_repeat_value(xpath, repeat_value):
include_images,
show_choice_labels=show_choice_labels,
language=language,
host=host,
)
return record

Expand Down Expand Up @@ -763,7 +770,7 @@ def _update_ordered_columns_from_data(self, cursor):
"""
# add ordered columns for select multiples
if self.split_select_multiples:
for (key, choices) in iteritems(self.select_multiples):
for key, choices in iteritems(self.select_multiples):
# HACK to ensure choices are NOT duplicated
if key in self.ordered_columns.keys():
self.ordered_columns[key] = remove_dups_from_list_maintain_order(
Expand All @@ -783,7 +790,7 @@ def _update_ordered_columns_from_data(self, cursor):
# add ordered columns for nested repeat data
for record in cursor:
# re index column repeats
for (key, value) in iteritems(record):
for key, value in iteritems(record):
self._reindex(
key,
value,
Expand All @@ -795,6 +802,7 @@ def _update_ordered_columns_from_data(self, cursor):
index_tags=self.index_tags,
show_choice_labels=self.show_choice_labels,
language=self.language,
host=self.host,
)

def _format_for_dataframe(self, cursor):
Expand All @@ -818,7 +826,7 @@ def _format_for_dataframe(self, cursor):
self._tag_edit_string(record)
flat_dict = {}
# re index repeats
for (key, value) in iteritems(record):
for key, value in iteritems(record):
reindexed = self._reindex(
key,
value,
Expand Down
Loading

0 comments on commit 1c94d26

Please sign in to comment.