Skip to content

Commit

Permalink
change: move to pycasbin with async stubs
Browse files Browse the repository at this point in the history
fix: adapter returns
  • Loading branch information
thearchitector committed Dec 30, 2023
1 parent f3da125 commit b94f0eb
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 151 deletions.
7 changes: 1 addition & 6 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,10 @@ on: [push, pull_request]
jobs:
test:
runs-on: ubuntu-latest
continue-on-error: ${{ matrix.experimental }}
strategy:
fail-fast: true
matrix:
python-version: [3.7, 3.8, 3.9, "3.10", "3.11"]
experimental: [false]
include:
- python-version: "3.12"
experimental: true
python-version: [3.7, 3.8, 3.9, "3.10", "3.11", "3.12"]

services:
postgres:
Expand Down
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.1.6
rev: v0.1.9
hooks:
- id: ruff
args: ["--fix", "--exit-non-zero-on-fix"]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
- repo: local
hooks:
- id: mypy
name: mypy
language: system
entry: .venv/Scripts/python.exe -m mypy
types: [python]
exclude: ^tests
require_serial: true
args: ["--explicit-package-bases", "--check-untyped-defs"]
additional_dependencies:
["asynccasbin<2.0.0,>=1.1.2", "tortoise-orm[accel]>=0.18.0"]
- repo: https://github.com/pdm-project/pdm
rev: 2.10.4
rev: 2.11.1
hooks:
- id: pdm-lock-check
- id: pdm-export
Expand Down
11 changes: 3 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@ python3 -m pip install --user casbin-tortoise-adapter
# or via your favorite dependency manager, like PDM
```

The current supported databases are [limited by Tortoise ORM](https://tortoise.github.io/databases.html), and include:

- PostgreSQL >= 9.4 (using `asyncpg`)
- SQLite (using `aiosqlite`)
- MySQL/MariaDB (using `asyncmy`)
- Microsoft SQL Server / Oracle (using `asyncodbc`)
The current supported databases are [limited by Tortoise ORM](https://tortoise.github.io/databases.html).

## Documentation

Expand All @@ -32,7 +27,7 @@ A custom Model, combined with advanced configuration like show in the Tortoise O
## Basic example

```python
from casbin import Enforcer
from casbin import AsyncEnforcer
from tortoise import Tortoise

from casbin_tortoise_adapter import CasbinRule, TortoiseAdapter
Expand All @@ -46,7 +41,7 @@ async def main()
await Tortoise.generate_schemas()

adapter = casbin_tortoise_adapter.TortoiseAdapter()
e = casbin.Enforcer('path/to/model.conf', adapter, True)
e = AsyncEnforcer('path/to/model.conf', adapter)

sub = "alice" # the user that wants to access a resource.
obj = "data1" # the resource that is going to be accessed.
Expand Down
93 changes: 62 additions & 31 deletions casbin_tortoise_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from dataclasses import asdict
from typing import TYPE_CHECKING

from casbin.persist import (
Adapter,
BatchAdapter,
FilteredAdapter,
load_policy_line, # pyright: ignore
from casbin.persist import load_policy_line
from casbin.persist.adapters.asyncio import (
AsyncAdapter,
AsyncBatchAdapter,
AsyncFilteredAdapter,
AsyncUpdateAdapter,
)
from casbin.persist.adapters.update_adapter import UpdateAdapter
from tortoise.expressions import Q

from .model import CasbinRule
Expand All @@ -27,26 +27,26 @@ class Assertion:
policy: List[RuleType]


class TortoiseAdapter(BatchAdapter, UpdateAdapter, FilteredAdapter, Adapter):
class TortoiseAdapter(
AsyncBatchAdapter, AsyncUpdateAdapter, AsyncFilteredAdapter, AsyncAdapter
):
"""An async Casbin adapter for Tortoise ORM."""

def __init__(self, modelclass: Type[CasbinRule] = CasbinRule) -> None:
if not issubclass(modelclass, CasbinRule): # pyright: ignore
if not issubclass(modelclass, CasbinRule):
raise TypeError(
"The provided model class must be a subclass of CasbinRule!"
)

self.modelclass: Type[CasbinRule] = modelclass
self._filtered: bool = False

async def load_policy(self, model: Model) -> None: # pyright: ignore
async def load_policy(self, model: Model) -> None:
"""Loads all policy rules from storage."""
for line in await self.modelclass.all():
load_policy_line(str(line), model)

async def load_filtered_policy( # pyright: ignore
self, model: Model, filter: RuleFilter
) -> None:
async def load_filtered_policy(self, model: Model, filter: RuleFilter) -> None:
"""Loads all policy rules that match the filter from storage."""
rules = await self.modelclass.filter(
**{f"{f}__in": v for f, v in asdict(filter).items() if v}
Expand All @@ -57,9 +57,9 @@ async def load_filtered_policy( # pyright: ignore

self._filtered = True

async def save_policy(self, model: Model) -> None: # pyright: ignore
async def save_policy(self, model: Model) -> None:
"""Saves all policy rules to storage."""
raw: Dict[str, Dict[str, Assertion]] = model.model # pyright: ignore
raw: Dict[str, Dict[str, Assertion]] = model.model
rules: List[CasbinRule] = [
self._to_rule(ptype, rule)
for sec in ("p", "g")
Expand All @@ -71,46 +71,76 @@ async def save_policy(self, model: Model) -> None: # pyright: ignore

async def add_policy( # pyright: ignore
self, sec: str, ptype: str, rule: RuleType
) -> None:
) -> bool:
"""Saves a policy rule to storage."""
await self._to_rule(ptype, rule).save()
return True

async def add_policies( # pyright: ignore
self, sec: str, ptype: str, rules: List[RuleType]
) -> None:
) -> bool:
"""Saves policy rules to storage."""
batch = [self._to_rule(ptype, rule) for rule in rules]
await self.modelclass.bulk_create(batch)
rs: List[CasbinRule] = await self.modelclass.bulk_create(batch) # type: ignore
return len(rs) > 0

async def update_policy( # pyright: ignore
self, sec: str, ptype: str, old_rule: RuleType, new_policy: RuleType
) -> None:
) -> bool:
"""
Updates a policy rule from storage. This is part of the Auto-Save feature.
"""
vs = {f"v{i}": rule for i, rule in enumerate(old_rule)}
r = self.modelclass.filter(ptype=ptype, **vs)
await r.update(
r = await self.modelclass.filter(ptype=ptype, **vs).update(
**{
f"v{i}": (new_policy[i] if i < len(new_policy) else None)
for i in range(6)
}
)
return r > 0

async def update_policies(
async def update_policies( # pyright: ignore
self,
sec: str,
ptype: str,
old_rules: List[RuleType],
new_rules: List[RuleType],
) -> None:
) -> bool:
"""Updates the old rules with the new rules."""
await asyncio.gather(
if not old_rules or not new_rules or (len(old_rules) != len(new_rules)):
raise ValueError(
"There must be at least one mapped pair of old and new rules."
)

rs = await asyncio.gather(
*[
self.update_policy(sec, ptype, old_rule, new_rule)
for old_rule, new_rule in zip(old_rules, new_rules)
]
)
return all(rs)

async def update_filtered_policies( # pyright: ignore
self,
sec: str,
ptype: str,
new_rules: List[RuleType],
field_index: int,
*field_values: Tuple[str],
) -> List[RuleType]:
"""Updates the old filtered rules with the new rules."""
if not (0 <= field_index <= 5) or not (
1 <= field_index + len(field_values) <= 6
):
return []

vs = {f"v{field_index + i}": v for i, v in enumerate(field_values) if v}
rs = await self.modelclass.filter(**vs).all()
old_rules = [self._from_rule(r) for r in rs]

await self.update_policies(sec, ptype, old_rules, new_rules)

return old_rules

async def remove_policy( # pyright: ignore
self, sec: str, ptype: str, rule: RuleType
Expand All @@ -132,27 +162,28 @@ async def remove_filtered_policy( # pyright: ignore
):
return False

r = 0
vs = {f"v{field_index + i}": v for i, v in enumerate(field_values) if v}
if vs:
r = await self.modelclass.filter(**vs).delete()

r = await self.modelclass.filter(**vs).delete()
return r > 0

async def remove_policies( # pyright: ignore
self, sec: str, ptype: str, rules: List[RuleType]
) -> None:
) -> bool:
"""Removes policy rules from storage."""
if not rules:
return
return False

qs = [Q(**{f"v{i}": v for i, v in enumerate(rule)}) for rule in rules]
await self.modelclass.filter(Q(*qs, join_type=Q.OR), ptype=ptype).delete()
r = await self.modelclass.filter(Q(*qs, join_type=Q.OR), ptype=ptype).delete()
return r > 0

def is_filtered(self) -> bool:
def is_filtered(self) -> bool: # pyright: ignore
"""Returns if the loaded policy is filtered or not."""
return self._filtered

def _to_rule(self, ptype: str, rule: RuleType) -> CasbinRule:
kwargs: Dict[str, str] = {f"v{i}": v for i, v in enumerate(rule)}
return self.modelclass(ptype=ptype, **kwargs)

def _from_rule(self, rule: CasbinRule) -> RuleType:
return str(rule).split(", ")[1:]
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ services:
POSTGRES_DB: "casbin_rule"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 10s
interval: 1s
timeout: 5s
retries: 5

Expand Down
Loading

0 comments on commit b94f0eb

Please sign in to comment.