Skip to content

Commit

Permalink
fix: replace manual ordering with SQL decryption
Browse files Browse the repository at this point in the history
Following @vshvechko's advice, replaced manual ordering strategy with
use of existing SQL decryption function, `decrypt_internal`, which
greatly simplifies code and makes it more maintainable. Preserved record
limit to ensure encrypted ordering requests do not overload SQL server.
  • Loading branch information
farmerpaul committed May 31, 2024
1 parent c4e513a commit b406b29
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 188 deletions.
208 changes: 82 additions & 126 deletions src/apps/shared/ordering.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import cmp_to_key
from typing import Literal, cast
from typing import Literal

from sqlalchemy import Column, asc, desc

Expand All @@ -10,49 +9,59 @@
from sqlalchemy.orm import InstrumentedAttribute

OrderingDirection = Literal["+", "-"]
OrderingField = str | tuple[str, int]

# Record limit for ordering by encrypted fields to minimize performance impact
ENCRYPTED_ORDERING_LIMIT = 300


class Ordering:
"""
Generates clauses for SQL ordering and utilities for manual ordering & pagination, if needed.
Generates clauses for SQL ordering, with support for conditional ordering by encrypted fields.
For SQL ordering, add supported fields directly as class attributes as schema columns or
For unencrypted fields, define them directly as class attributes as schema columns or
ordering clauses.
For manual ordering (i.e. required by encrypted fields), add them to the manual_fields
dictionary, with the key representing the requested ordering field, and the value representing
the column name as returned in each result record.
**Note:** Also include in this dictionary any fields that are also supported by SQL
ordering so that multi-key ordering continues to be supported when ordering manually.
For encrypted fields, add them to the encrypted_fields dictionary as ordering clauses
(including decryption logic).
Clauses returned by get_clauses will only include requested encrypted fields only if record
count is below ENCRYPTED_ORDERING_LIMIT.
Example:
class ExampleOrdering(Ordering):
```
class BasicOrdering(Ordering):
id = Schema.id
date = Schema.created_at
ordering = BasicOrdering()
query.order_by(*ordering.get_clauses('-id', 'date'))
# Will give result as SQL:
# select * from schema order by id desc, created_at asc
class EncryptedOrdering(Ordering):
id = Schema.id
date = Schema.created_at
# manual_fields is optional, only needed if one or more fields are encrypted
manual_fields = {
"email": "email_encrypted", # encrypted scalar field
"nicknames": ("nicknames", 0), # encrypted array field, index 0
"id": "id",
"date": "created_at",
# encrypted_fields is optional, only needed if one or more fields are encrypted
encrypted_fields = {
"email": Ordering.Clause(func.decrypt_internal(UserSchema.email, get_key())),
}
ordering = ExampleOrdering()
query.order_by(*ordering.get_clauses('-id', 'date'))
# will give result as SQL:
# select * from schema order by id desc, created_at asc
manual_fields = ordering.get_manual_fields("nicknames", "email", "-date")
if manual_fields:
# run SQL query without ORDER BY query
# then manually sort the results using manual_sort utility:
data = ordering.manual_sort(data, manual_fields)
# manually sorts decrypted data by:
# nicknames[0] (asc), email_encrypted (asc), then created_at (desc)
data = ordering.manual_paginate(data, 1, 10)
# manually paginates data, returns slice of first 10 records
ordering = EncryptedOrdering()
count = (await session.execute(
select(count()).select_from(query.with_only_columns(Schema.id).subquery())
)).scalar()
query.order_by(*ordering.get_clauses_encrypted("email", "-date", count=count))
# If count < ENCRYPTED_ORDERING_LIMIT, will give result as SQL:
# select * from schema order by decrypt_internal(email, …) asc, created_at desc
# and EncryptedOrdering().get_ordering_fields(count) will return ["id", "date", "email"]
#
# Else if count >= ENCRYPTED_ORDERING_LIMIT, will give result as SQL:
# select * from schema order by created_at desc
# and EncryptedOrdering().get_ordering_fields(count) will return ["id", "date"]
```
"""

class Clause:
Expand All @@ -63,70 +72,71 @@ class Clause:
def __init__(self, clause):
self.clause = clause

sql_fields: dict[str, Column] = dict()
manual_fields: dict[str, OrderingField] = dict()
fields: dict[str, InstrumentedAttribute | Column | Clause] = dict()
encrypted_fields: dict[str, InstrumentedAttribute | Column | Clause] = dict()

actions = {
"+": asc,
"-": desc,
}

def __init__(self):
self.sql_fields = dict()
self.fields = dict()
for key, val in self.__class__.__dict__.items():
_val = None
if isinstance(val, (InstrumentedAttribute, Column)):
_val = val
elif isinstance(val, Ordering.Clause):
_val = val.clause
if _val is not None:
self.sql_fields[key] = _val
self.fields[key] = _val

# Returns array of parsed ordering arguments if any of the fields require manual sorting
# (defined is self.manual_fields, but not in self.fields), else returns None.
def get_manual_fields(self, *args: str):
has_manual_only_fields = False
parsed_fields: list[tuple[OrderingDirection, OrderingField]] = []
for value in args:
parsed_field = self._parse_ordered_field(value)
if parsed_field is None:
continue
direction, field = parsed_field
if not self._is_manual_field(field):
continue
parsed_fields.append((direction, self.manual_fields[field]))
# Flag if any of the fields cannot be sorted by SQL (i.e. manual only)
if not self._is_sql_field(field):
has_manual_only_fields = True
return parsed_fields if has_manual_only_fields else None

# Returns SQL clauses suitable for Query.order_by based on provided "+field"-style args
def get_clauses(self, *args):
for key, val in self.encrypted_fields.items():
_val = None
if isinstance(val, (InstrumentedAttribute, Column)):
_val = val
elif isinstance(val, Ordering.Clause):
_val = val.clause
if _val is not None:
self.encrypted_fields[key] = _val

def get_clauses(self, *args: str, count: int = 0):
"""
Returns SQL clauses suitable for Query.order_by based on provided "+field"-style args,
including any requested encrypted fields only if provided record count is below
ENCRYPTED_ORDERING_LIMIT.
"""
clauses: list[str] = []
for value in args:
parsed_field = self._parse_ordered_field(value)
if parsed_field is None:
continue
_direction, field = parsed_field
if not self._is_sql_field(field):

direction, field = parsed_field
clause = None
if field in self.fields:
clause = self._prepare_sql_clause((direction, self.fields[field]))
elif field in self.encrypted_fields and count < ENCRYPTED_ORDERING_LIMIT:
clause = self._prepare_sql_clause((direction, self.encrypted_fields[field]))
else:
continue
clause = self._prepare_sql_clause(parsed_field)

if clause is not None:
clauses.append(clause)
return clauses

# Returns list of fields that this Ordering class supports for sorting, based on whether
# to use manual sorting or not. Returns field names in camelCase suitable for API output.
def get_ordering_fields(self, use_manual_sorting: bool = False):
use_manual_sorting = use_manual_sorting and bool(self.manual_fields)
fields = (self.manual_fields if use_manual_sorting else self.sql_fields).keys()
return list(to_camelcase(word) for word in fields)

def _is_sql_field(self, field: str):
return field in self.sql_fields
def get_ordering_fields(self, count: int = 0):
"""
Returns list of fields that this Ordering class supports for sorting, which includes
encrypted fields only if provided record count is below ENCRYPTED_ORDERING_LIMIT.
def _is_manual_field(self, field: str):
return field in self.manual_fields
Returns field names in camelCase suitable for API output.
"""
fields = [*self.fields.keys()]
include_encrypted_fields = bool(self.encrypted_fields) and count < ENCRYPTED_ORDERING_LIMIT
if include_encrypted_fields:
fields += self.encrypted_fields.keys()
return list(to_camelcase(word) for word in fields)

def _parse_ordered_field(self, value: str):
if not value:
Expand All @@ -138,61 +148,7 @@ def _parse_ordered_field(self, value: str):
field = field.lstrip("+")
return (direction, field)

def _prepare_sql_clause(self, value: tuple[OrderingDirection, str]):
direction, field = value
if not self._is_sql_field(field):
return None
def _prepare_sql_clause(self, value: tuple[OrderingDirection, InstrumentedAttribute | Column | Clause]):
direction, expression = value

return self.actions[direction](self.sql_fields[field])

def manual_sort(self, data: list, fields: list[tuple[OrderingDirection, OrderingField]]):
def comparer(left, right):
for direction, field in fields:
left_value = right_value = None
# Extract array field value
if isinstance(field, tuple):
field, index = field
left_value = left[field][index]
right_value = right[field][index]
# Extract scalar field value
else:
left_value = left[field]
right_value = right[field]

# Protect against None values (None ranks lowest in SQL by default)
if left_value is None and right_value is None:
continue
elif left_value is None:
# Set left value to maximum possible value for the type
if isinstance(right_value, (int, float)):
left_value = float("inf") # maximum number
elif isinstance(right_value, str):
left_value = chr(1114111) # maximum string
elif right_value is None:
# Set right value to maximum possible value for the type
if isinstance(left_value, (int, float)):
right_value = float("inf") # maximum number
elif isinstance(left_value, str):
right_value = chr(1114111) # maximum string

result = 0
# Need to do conditional type casting strictly to satisfy type checker
if isinstance(left_value, (int, float)):
right_value = cast(float, right_value)
result = (left_value > right_value) - (left_value < right_value)
elif isinstance(left_value, str):
right_value = cast(str, right_value)
# Convert to lowercase for case-insensitive ordering
left_value = left_value.lower()
right_value = right_value.lower()
result = (left_value > right_value) - (left_value < right_value)

multiplier = -1 if direction == "-" else 1
if result:
return multiplier * result
return 0

return sorted(data, key=cmp_to_key(comparer))

def manual_paginate(self, data: list, page_number: int, page_size: int):
return data[(page_number - 1) * page_size : page_number * page_size]
return self.actions[direction](expression)
81 changes: 19 additions & 62 deletions src/apps/workspaces/crud/user_applet_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@
__all__ = ["UserAppletAccessCRUD"]


# Record limits for manual sorting of respondents and managers based on resource use
RESPONDENTS_MANUAL_SORT_LIMIT = 200
MANAGERS_MANUAL_SORT_LIMIT = 200


class _AppletUsersFilter(Filtering):
role = FilterField(UserAppletAccessSchema.role)
shell = FilterField(UserSchema.id, method_name="null")
Expand All @@ -72,15 +67,13 @@ class _WorkspaceRespondentOrdering(Ordering):
# TODO: https://mindlogger.atlassian.net/browse/M2-6834
# Add support to order by last_seen

# Because nickname is encrypted, we need to support manual sorting by that field as well as any
# other field (in order to support manual sorting by multiple keys)
manual_fields = {
"nicknames": ("nicknames", 0),
"is_pinned": "is_pinned",
"secret_ids": ("secret_ids", 0),
"tags": ("tags_order", 0),
"created_at": "created_at",
"status": "status_order",
encrypted_fields = {
"nicknames": Ordering.Clause(
func.array_remove(
func.array_agg(func.distinct(func.decrypt_internal(SubjectSchema.nickname, get_key()))),
None,
)
)
}


Expand All @@ -104,16 +97,10 @@ class _AppletManagersOrdering(Ordering):
last_seen = Ordering.Clause(literal_column("last_seen"))
roles = Ordering.Clause(literal_column("roles"))

# Because email, first name and last name are encrypted, we need to support manual sorting by
# those fields as well as any other field (in order to support manual sorting by multiple keys)
manual_fields = {
"email": "email_encrypted",
"first_name": "first_name",
"last_name": "last_name",
"created_at": "created_at",
"is_pinned": "is_pinned",
"last_seen": "last_seen",
"roles": ("roles", 0),
encrypted_fields = {
"email": Ordering.Clause(func.decrypt_internal(UserSchema.first_name, get_key())),
"first_name": Ordering.Clause(func.decrypt_internal(UserSchema.first_name, get_key())),
"last_name": Ordering.Clause(func.decrypt_internal(UserSchema.last_name, get_key())),
}


Expand Down Expand Up @@ -573,30 +560,15 @@ async def get_workspace_respondents(
total = (await coro_total).scalar()

ordering = _WorkspaceRespondentOrdering()
manual_order_fields = None

if query_params.ordering:
if total < RESPONDENTS_MANUAL_SORT_LIMIT:
manual_order_fields = ordering.get_manual_fields(*query_params.ordering)
# Only perform SQL ORDER BY if all requested ordering fields can be done by SQL
if not manual_order_fields:
query = query.order_by(*ordering.get_clauses(*query_params.ordering))

# If able to use SQL ordering, also use SQL-based paging; else paginate post-execute
if not manual_order_fields:
query = paging(query, query_params.page, query_params.limit)
query = query.order_by(*ordering.get_clauses(*query_params.ordering, count=total))

res_data = await self._execute(query)

data = res_data.all()

# If sorting manually, both sort and paginate manually
if manual_order_fields:
data = ordering.manual_sort(data, manual_order_fields)
data = ordering.manual_paginate(data, query_params.page, query_params.limit)
query = paging(query, query_params.page, query_params.limit)

data = (await self._execute(query)).all()
data = parse_obj_as(list[WorkspaceRespondent], data)
ordering_fields = ordering.get_ordering_fields(total < RESPONDENTS_MANUAL_SORT_LIMIT)
ordering_fields = ordering.get_ordering_fields(total)

return data, total, ordering_fields

Expand Down Expand Up @@ -696,30 +668,15 @@ async def get_workspace_managers(
total = (await coro_total).scalar()

ordering = _AppletManagersOrdering()
manual_order_fields = None

if query_params.ordering:
if total < MANAGERS_MANUAL_SORT_LIMIT:
manual_order_fields = ordering.get_manual_fields(*query_params.ordering)
# Only perform SQL ORDER BY if all requested ordering fields can be done by SQL
if not manual_order_fields:
query = query.order_by(*ordering.get_clauses(*query_params.ordering))

# If able to use SQL ordering, also use SQL-based paging; else paginate post-execute
if not manual_order_fields:
query = paging(query, query_params.page, query_params.limit)

res_data = await self._execute(query)

data = res_data.all()
query = query.order_by(*ordering.get_clauses(*query_params.ordering, count=total))

# If sorting manually, both sort and paginate manually
if manual_order_fields:
data = ordering.manual_sort(data, manual_order_fields)
data = ordering.manual_paginate(data, query_params.page, query_params.limit)
query = paging(query, query_params.page, query_params.limit)

data = (await self._execute(query)).all()
data = parse_obj_as(list[WorkspaceManager], data)
ordering_fields = ordering.get_ordering_fields(total < MANAGERS_MANUAL_SORT_LIMIT)
ordering_fields = ordering.get_ordering_fields(total)

# TODO: Fix via class Searching
# using database fields - StringEncryptedType
Expand Down

0 comments on commit b406b29

Please sign in to comment.