|
| 1 | +from sqlorm import Engine, Model as _Model, get_current_session, init_db, migrate, create_all |
| 2 | +from sqlorm.engine import session_context |
| 3 | +import sqlorm |
| 4 | +import abc |
| 5 | +import os |
| 6 | +import click |
| 7 | +from flask import g, abort, has_request_context |
| 8 | +from flask.cli import AppGroup |
| 9 | +from werkzeug.local import LocalProxy |
| 10 | + |
| 11 | + |
| 12 | +class FlaskSQLORM: |
| 13 | + def __init__(self, app=None, *args, **kwargs): |
| 14 | + if app: |
| 15 | + self.init_app(app, *args, **kwargs) |
| 16 | + |
| 17 | + def init_app(self, app, database_uri="sqlite://:memory:", **engine_kwargs): |
| 18 | + self.app = app |
| 19 | + |
| 20 | + for key in dir(sqlorm): |
| 21 | + if not key.startswith("_") and not hasattr(self, key): |
| 22 | + setattr(self, key, getattr(sqlorm, key)) |
| 23 | + |
| 24 | + config = app.config.get_namespace("SQLORM_") |
| 25 | + database_uri = config.pop("uri", database_uri) |
| 26 | + for key, value in engine_kwargs.items(): |
| 27 | + config.setdefault(key, value) |
| 28 | + config.setdefault("logger", app.logger) |
| 29 | + if database_uri.startswith("sqlite://"): |
| 30 | + config.setdefault("fine_tune", True) |
| 31 | + |
| 32 | + self.engine = Engine.from_uri(database_uri, **config) |
| 33 | + self.session = LocalProxy(get_current_session) |
| 34 | + self.Model = Model.bind(self.engine) |
| 35 | + |
| 36 | + @app.before_request |
| 37 | + def start_db_session(): |
| 38 | + g.sqlorm_session = self.engine.make_session() |
| 39 | + session_context.push(g.sqlorm_session) |
| 40 | + |
| 41 | + @app.after_request |
| 42 | + def close_db_session(response): |
| 43 | + session_context.pop() |
| 44 | + g.sqlorm_session.close() |
| 45 | + return response |
| 46 | + |
| 47 | + cli = AppGroup("db", help="Commands to manage your database") |
| 48 | + |
| 49 | + @cli.command() |
| 50 | + def init(): |
| 51 | + """Initializes the database, either creating tables for models or running migrations if some exists""" |
| 52 | + self.init_db() |
| 53 | + |
| 54 | + @cli.command() |
| 55 | + def create_all(): |
| 56 | + """Create all tables associated to models""" |
| 57 | + self.create_all() |
| 58 | + |
| 59 | + @cli.command() |
| 60 | + @click.option("--from", "from_", type=int) |
| 61 | + @click.option("--to", type=int) |
| 62 | + @click.option("--dryrun", is_flag=True) |
| 63 | + @click.option("--ignore-schema-version", is_flag=True) |
| 64 | + def migrate(from_, to, dryrun, ignore_schema_version): |
| 65 | + """Run database migrations from the migrations folder in your app root path""" |
| 66 | + self.migrate(from_version=from_, to_version=to, dryrun=dryrun, use_schema_version=not ignore_schema_version) |
| 67 | + |
| 68 | + app.cli.add_command(cli) |
| 69 | + |
| 70 | + def __enter__(self): |
| 71 | + if has_request_context(): |
| 72 | + return g.sqlorm_session.__enter__() |
| 73 | + return self.engine.__enter__() |
| 74 | + |
| 75 | + def __exit__(self, exc_type, exc_value, exc_tb): |
| 76 | + if has_request_context(): |
| 77 | + g.sqlorm_session.__exit__(exc_type, exc_value, exc_tb) |
| 78 | + else: |
| 79 | + self.engine.__exit__(exc_type, exc_value, exc_tb) |
| 80 | + |
| 81 | + def create_all(self, **kwargs): |
| 82 | + kwargs.setdefault("model_registry", self.Model.__model_registry__) |
| 83 | + with self.engine: |
| 84 | + create_all(**kwargs) |
| 85 | + |
| 86 | + def init_db(self, **kwargs): |
| 87 | + kwargs.setdefault("path", os.path.join(self.app.root_path, "migrations")) |
| 88 | + kwargs.setdefault("model_registry", self.Model.__model_registry__) |
| 89 | + kwargs.setdefault("logger", self.app.logger) |
| 90 | + with self.engine: |
| 91 | + init_db(**kwargs) |
| 92 | + |
| 93 | + def migrate(self, **kwargs): |
| 94 | + kwargs.setdefault("path", os.path.join(self.app.root_path, "migrations")) |
| 95 | + kwargs.setdefault("logger", self.app.logger) |
| 96 | + with self.engine: |
| 97 | + migrate(**kwargs) |
| 98 | + |
| 99 | + |
| 100 | +class Model(_Model, abc.ABC): |
| 101 | + @classmethod |
| 102 | + def find_one_or_404(cls, *args, **kwargs): |
| 103 | + obj = cls.find_one(*args, **kwargs) |
| 104 | + if not obj: |
| 105 | + abort(404) |
| 106 | + return obj |
| 107 | + |
| 108 | + @classmethod |
| 109 | + def get_or_404(cls, *args, **kwargs): |
| 110 | + obj = cls.get(*args, **kwargs) |
| 111 | + if not obj: |
| 112 | + abort(404) |
| 113 | + return obj |
0 commit comments