Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a basic Django model generator. #556

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 90 additions & 15 deletions gel/_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import inspect
import json
import logging
import pathlib
import os
import re
import subprocess
Expand All @@ -37,7 +38,8 @@
from gel import asyncio_client
from gel import blocking_client
from gel.orm.introspection import get_schema_json
from gel.orm.sqla import ModelGenerator
from gel.orm.sqla import ModelGenerator as SQLAModGen
from gel.orm.django.generator import ModelGenerator as DjangoModGen


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -630,13 +632,13 @@ def adapt_call(cls, result):
return result


class SQLATestCase(SyncQueryTestCase):
SQLAPACKAGE = None
class ORMTestCase(SyncQueryTestCase):
MODEL_PACKAGE = None
DEFAULT_MODULE = 'default'

@classmethod
def setUpClass(cls):
# SQLAlchemy relies on psycopg2 to connect to Postgres and thus we
# ORMs rely on psycopg2 to connect to Postgres and thus we
# need it to run tests. Unfortunately not all test environemnts might
# have psycopg2 installed, as long as we run this in the test
# environments that have this, it is fine since we're not expecting
Expand All @@ -648,24 +650,34 @@ def setUpClass(cls):

class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP')
if not class_set_up:
# Now that the DB is setup, generate the SQLAlchemy models from it
spec = get_schema_json(cls.client)
# We'll need a temp directory to setup the generated Python
# package
cls.tmpsqladir = tempfile.TemporaryDirectory()
gen = ModelGenerator(
outdir=os.path.join(cls.tmpsqladir.name, cls.SQLAPACKAGE),
basemodule=cls.SQLAPACKAGE,
)
gen.render_models(spec)
sys.path.append(cls.tmpsqladir.name)
cls.tmpormdir = tempfile.TemporaryDirectory()
sys.path.append(cls.tmpormdir.name)
# Now that the DB is setup, generate the ORM models from it
cls.spec = get_schema_json(cls.client)
cls.setupORM()

@classmethod
def setupORM(cls):
raise NotImplementedError

@classmethod
def tearDownClass(cls):
super().tearDownClass()
# cleanup the temp modules
sys.path.remove(cls.tmpsqladir.name)
cls.tmpsqladir.cleanup()
sys.path.remove(cls.tmpormdir.name)
cls.tmpormdir.cleanup()


class SQLATestCase(ORMTestCase):
@classmethod
def setupORM(cls):
gen = SQLAModGen(
outdir=os.path.join(cls.tmpormdir.name, cls.MODEL_PACKAGE),
basemodule=cls.MODEL_PACKAGE,
)
gen.render_models(cls.spec)

@classmethod
def get_dsn_for_sqla(cls):
Expand All @@ -678,6 +690,69 @@ def get_dsn_for_sqla(cls):
return dsn


APPS_PY = '''\
from django.apps import AppConfig


class TestConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = {name!r}
'''

SETTINGS_PY = '''\
from pathlib import Path

mysettings = dict(
INSTALLED_APPS=[
'{appname}.apps.TestConfig',
'gel.orm.django.gelmodels.apps.GelPGModel',
],
DATABASES={{
'default': {{
'ENGINE': 'django.db.backends.postgresql',
'NAME': {database!r},
'USER': {user!r},
'PASSWORD': {password!r},
'HOST': {host!r},
'PORT': {port!r},
}}
}},
)
'''


class DjangoTestCase(ORMTestCase):
@classmethod
def setupORM(cls):
pkgbase = os.path.join(cls.tmpormdir.name, cls.MODEL_PACKAGE)
# Set up the package for testing Django models
os.mkdir(pkgbase)
open(os.path.join(pkgbase, '__init__.py'), 'w').close()
with open(os.path.join(pkgbase, 'apps.py'), 'wt') as f:
print(
APPS_PY.format(name=cls.MODEL_PACKAGE),
file=f,
)

with open(os.path.join(pkgbase, 'settings.py'), 'wt') as f:
cargs = cls.get_connect_args(database=cls.get_database_name())
print(
SETTINGS_PY.format(
appname=cls.MODEL_PACKAGE,
database=cargs["database"],
user=cargs["user"],
password=cargs["password"],
host=cargs["host"],
port=cargs["port"],
),
file=f,
)

models = os.path.join(pkgbase, 'models.py')
gen = DjangoModGen(out=models)
gen.render_models(cls.spec)


_lock_cnt = 0


Expand Down
7 changes: 7 additions & 0 deletions gel/orm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import unittest

# No tests here, but we want to skip the unittest loader from attempting to
# import ORM packages which may not have been installed (like Django that has
# a few custom adjustments to make our models work).
def load_tests(loader, tests, pattern):
return tests
18 changes: 14 additions & 4 deletions gel/orm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

from gel.codegen.generator import _get_conn_args
from .introspection import get_schema_json
from .sqla import ModelGenerator
from .sqla import ModelGenerator as SQLAModGen
from .django.generator import ModelGenerator as DjangoModGen


class ArgumentParser(argparse.ArgumentParser):
Expand Down Expand Up @@ -65,7 +66,6 @@ def error(self, message):
"--mod",
help="The fullname of the Python module corresponding to the output "
"directory.",
required=True,
)


Expand All @@ -74,13 +74,23 @@ def main():
# setup client
client = gel.create_client(**_get_conn_args(args))
spec = get_schema_json(client)
generate_models(args, spec)


def generate_models(args, spec):
match args.orm:
case 'sqlalchemy':
gen = ModelGenerator(
if args.mod is None:
parser.error('sqlalchemy requires to specify --mod')

gen = SQLAModGen(
outdir=args.out,
basemodule=args.mod,
)
gen.render_models(spec)

case 'django':
print('Not available yet. Coming soon!')
gen = DjangoModGen(
out=args.out,
)
gen.render_models(spec)
Empty file added gel/orm/django/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions gel/orm/django/gelmodels/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import django

__version__ = "0.0.1"
31 changes: 31 additions & 0 deletions gel/orm/django/gelmodels/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from django.apps import AppConfig


class GelPGModel(AppConfig):
name = "gel.orm.django.gelmodels"

def ready(self):
from django.db import connections, utils

gel_compiler_module = "gel.orm.django.gelmodels.compiler"

# Change the current compiler_module
for c in connections:
connections[c].ops.compiler_module = gel_compiler_module

# Update the load_backend to use our DatabaseWrapper
orig_load_backend = utils.load_backend

def custom_load_backend(*args, **kwargs):
backend = orig_load_backend(*args, **kwargs)

class GelPGBackend:
@staticmethod
def DatabaseWrapper(*args2, **kwargs2):
connection = backend.DatabaseWrapper(*args2, **kwargs2)
connection.ops.compiler_module = gel_compiler_module
return connection

return GelPGBackend

utils.load_backend = custom_load_backend
64 changes: 64 additions & 0 deletions gel/orm/django/gelmodels/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from django.db.models.sql.compiler import ( # noqa
SQLAggregateCompiler,
SQLCompiler,
SQLDeleteCompiler,
)
from django.db.models.sql.compiler import ( # noqa
SQLInsertCompiler as BaseSQLInsertCompiler,
)
from django.db.models.sql.compiler import ( # noqa
SQLUpdateCompiler as BaseSQLUpdateCompiler,
)


class GelSQLCompilerMixin:
'''
The reflected models have two special fields: `id` and `obj_type`. Both of
those fields should be read-only as they are populated automatically by
Gel and must not be modified.
'''
@property
def readonly_gel_fields(self):
try:
# Verify that this is a Gel model reflected via Postgres protocol.
gel_pg_meta = getattr(self.query.model, "GelPGMeta")
except AttributeError:
return set()
else:
return {'id', 'gel_type_id'}

def as_sql(self):
readonly_gel_fields = self.readonly_gel_fields
if readonly_gel_fields:
self.remove_readonly_gel_fields(readonly_gel_fields)
return super().as_sql()


class SQLUpdateCompiler(GelSQLCompilerMixin, BaseSQLUpdateCompiler):
def remove_readonly_gel_fields(self, names):
'''
Remove the values corresponding to the read-only fields.
'''
values = self.query.values
# The tuple is (field, model, value)
values[:] = (tup for tup in values if tup[0].name not in names)


class SQLInsertCompiler(GelSQLCompilerMixin, BaseSQLInsertCompiler):
def remove_readonly_gel_fields(self, names):
'''
Remove the read-only fields.
'''
fields = self.query.fields

try:
fields[:] = (f for f in fields if f.name not in names)
except AttributeError:
# When deserializing, we might get an attribute error because this
# list shoud be copied first:
#
# "AttributeError: The return type of 'local_concrete_fields'
# should never be mutated. If you want to manipulate this list for
# your own use, make a copy first."

self.query.fields = [f for f in fields if f.name not in names]
Loading
Loading