Skip to content

Commit

Permalink
Update the backlink naming scheme for ORM generators.
Browse files Browse the repository at this point in the history
The backlinks are going to have simpler and shorter names mimicking the
EdgeQL `.<link[is Type]` by using the `_link_Type` naming format.
  • Loading branch information
vpetrovykh committed Feb 14, 2025
1 parent 3a51966 commit b65440c
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 97 deletions.
2 changes: 1 addition & 1 deletion gel/orm/django/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def table(self):
return self.meta['db_table'].strip("'")

def get_backlink_name(self, name, srcname):
return self.backlink_renames.get(name, f'back_to_{srcname}')
return f'_{name}_{srcname}'


class ModelGenerator(FilePrinter):
Expand Down
29 changes: 2 additions & 27 deletions gel/orm/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def _process_links(types, modules):

objtype = type_map[target]
objtype['backlinks'].append({
'name': f'back_to_{sql_source}',
# naming scheme mimics .<link[is Type]
'name': f'_{sql_name}_{sql_source}',
'fwname': sql_name,
# flip cardinality and exclusivity
'cardinality': 'One' if exclusive else 'Many',
Expand All @@ -198,7 +199,6 @@ def _process_links(types, modules):
'has_link_object': False,
})


link['has_link_object'] = False
# Any link with properties should become its own intermediate
# object, since ORMs generally don't have a special convenient
Expand Down Expand Up @@ -232,31 +232,6 @@ def _process_links(types, modules):
'target': target,
})

# Go over backlinks and resolve any name collisions using the type map.
for spec in types:
mod = spec["name"].rsplit('::', 1)[0]
sql_source = get_sql_name(spec["name"])

# Find collisions in backlink names
bk = collections.defaultdict(list)
for link in spec['backlinks']:
if link['name'].startswith('back_to_'):
bk[link['name']].append(link)

for bklinks in bk.values():
if len(bklinks) > 1:
# We have a collision, so each backlink in it must now be
# disambiguated.
for link in bklinks:
origsrc = get_sql_name(link['target']['name'])
lname = link['name']
fwname = link['fwname']
link['name'] = f'follow_{fwname}_{lname}'
# Also update the original source of the link with the
# special backlink name.
source = type_map[link['target']['name']]
source['backlink_renames'][fwname] = link['name']

return {
'modules': modules,
'object_types': types,
Expand Down
7 changes: 2 additions & 5 deletions gel/orm/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,7 @@ def render_link_object(self, spec, modules):
bklink = source_link
else:
src = modules[mod]['object_types'][source_name]
bklink = src['backlink_renames'].get(
source_link,
f'back_to_{source_name}',
)
bklink = f'_{source_link}_{source_name}'

self.write(
f'{lname}: orm.Mapped[{pyname}] = '
Expand Down Expand Up @@ -418,7 +415,7 @@ def render_link(self, spec, mod, parent, modules):
tmod, target = get_mod_and_name(spec['target']['name'])
source = modules[mod]['object_types'][parent]
cardinality = spec['cardinality']
bklink = source['backlink_renames'].get(name, f'back_to_{parent}')
bklink = f'_{name}_{parent}'

if spec.get('has_link_object'):
# intermediate object will have the actual source and target
Expand Down
7 changes: 2 additions & 5 deletions gel/orm/sqlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,7 @@ def render_link_object(self, spec, modules):
bklink = source_link
else:
src = modules[mod]['object_types'][source_name]
bklink = src['backlink_renames'].get(
source_link,
f'back_to_{source_name}',
)
bklink = f'_{source_link}_{source_name}'

self.write(
f'{lname}: {pyname} = sm.Relationship(')
Expand Down Expand Up @@ -452,7 +449,7 @@ def render_link(self, spec, mod, parent, modules):
tmod, target = get_mod_and_name(spec['target']['name'])
source = modules[mod]['object_types'][parent]
cardinality = spec['cardinality']
bklink = source['backlink_renames'].get(name, f'back_to_{parent}')
bklink = f'_{name}_{parent}'

if tmod != 'default':
warnings.warn(
Expand Down
36 changes: 18 additions & 18 deletions tests/test_django_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_django_read_models_02(self):
# use backlink
res = self.m.User.objects.order_by('name').all()
vals = [
(u.name, {p.body for p in u.back_to_Post.all()})
(u.name, {p.body for p in u._author_Post.all()})
for u in res
]
self.assertEqual(
Expand Down Expand Up @@ -149,10 +149,10 @@ def test_django_read_models_03(self):
)

# prefetch via backlink
res = self.m.User.objects.prefetch_related('back_to_Post') \
.order_by('back_to_Post__body')
res = self.m.User.objects.prefetch_related('_author_Post') \
.order_by('_author_Post__body')
vals = {
(u.name, tuple(p.body for p in u.back_to_Post.all()))
(u.name, tuple(p.body for p in u._author_Post.all()))
for u in res
}
self.assertEqual(
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_django_read_models_04(self):
# use backlink
res = self.m.User.objects.all()
vals = {
(u.name, tuple(g.num for g in u.back_to_GameSession.all()))
(u.name, tuple(g.num for g in u._players_GameSession.all()))
for u in res
}
self.assertEqual(
Expand Down Expand Up @@ -216,9 +216,9 @@ def test_django_read_models_05(self):
)

# prefetch via backlink
res = self.m.User.objects.prefetch_related('back_to_GameSession')
res = self.m.User.objects.prefetch_related('_players_GameSession')
vals = {
(u.name, tuple(g.num for g in u.back_to_GameSession.all()))
(u.name, tuple(g.num for g in u._players_GameSession.all()))
for u in res
}
self.assertEqual(
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_django_read_models_06(self):
# use backlink
res = self.m.User.objects.order_by('name').all()
vals = [
(u.name, {g.name for g in u.back_to_UserGroup.all()})
(u.name, {g.name for g in u._users_UserGroup.all()})
for u in res
]
self.assertEqual(
Expand Down Expand Up @@ -284,9 +284,9 @@ def test_django_read_models_07(self):
)

# prefetch via backlink
res = self.m.User.objects.prefetch_related('back_to_UserGroup')
res = self.m.User.objects.prefetch_related('_users_UserGroup')
vals = {
(u.name, tuple(sorted(g.name for g in u.back_to_UserGroup.all())))
(u.name, tuple(sorted(g.name for g in u._users_UserGroup.all())))
for u in res
}
self.assertEqual(
Expand Down Expand Up @@ -347,7 +347,7 @@ def test_django_create_models_02(self):
user = self.m.User.objects.get(name=name)

self.assertEqual(user.name, name)
self.assertEqual(user.back_to_UserGroup.all()[0].name, 'cyan')
self.assertEqual(user._users_UserGroup.all()[0].name, 'cyan')
self.assertIsInstance(user.id, uuid.UUID)

def test_django_create_models_03(self):
Expand All @@ -359,8 +359,8 @@ def test_django_create_models_03(self):
y.save()
cyan.save()

x.back_to_UserGroup.add(cyan)
y.back_to_UserGroup.add(cyan)
x._users_UserGroup.add(cyan)
y._users_UserGroup.add(cyan)

group = self.m.UserGroup.objects.get(name='cyan')
self.assertEqual(group.name, 'cyan')
Expand Down Expand Up @@ -443,8 +443,8 @@ def test_django_delete_models_05(self):

group.delete()
# make sure the user object is no longer a link target
user.back_to_UserGroup.clear()
user.back_to_GameSession.clear()
user._users_UserGroup.clear()
user._players_GameSession.clear()
user.delete()

vals = self.m.UserGroup.objects.filter(name='green').all()
Expand Down Expand Up @@ -476,13 +476,13 @@ def test_django_update_models_02(self):
blue.users.add(user)

self.assertEqual(
{g.name for g in user.back_to_UserGroup.all()},
{g.name for g in user._users_UserGroup.all()},
{'red', 'blue'},
)
self.assertEqual(user.name, 'Yvonne')
self.assertIsInstance(user.id, uuid.UUID)

group = [g for g in user.back_to_UserGroup.all()
group = [g for g in user._users_UserGroup.all()
if g.name == 'red'][0]
self.assertEqual(
{u.name for u in group.users.all()},
Expand All @@ -493,7 +493,7 @@ def test_django_update_models_03(self):
user0 = self.m.User.objects.get(name='Elsa')
user1 = self.m.User.objects.get(name='Zoe')
# Replace the author or a post
post = user0.back_to_Post.all()[0]
post = user0._author_Post.all()[0]
body = post.body
post.author = user1
post.save()
Expand Down
36 changes: 18 additions & 18 deletions tests/test_sqla_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_sqla_read_models_02(self):
# use backlink
res = self.sess.query(self.sm.User).order_by(self.sm.User.name).all()
vals = [
(u.name, {p.body for p in u.back_to_Post})
(u.name, {p.body for p in u._author_Post})
for u in res
]
self.assertEqual(
Expand Down Expand Up @@ -149,13 +149,13 @@ def test_sqla_read_models_03(self):
# join via backlink
res = self.sess.execute(
select(self.sm.Post, self.sm.User)
.join(self.sm.User.back_to_Post)
.join(self.sm.User._author_Post)
.order_by(self.sm.Post.body)
)
# We'll get a cross-product, so we need to jump through some hoops to
# remove duplicates
vals = {
(u.name, tuple(p.body for p in u.back_to_Post))
(u.name, tuple(p.body for p in u._author_Post))
for (_, u) in res
}
self.assertEqual(
Expand All @@ -170,11 +170,11 @@ def test_sqla_read_models_03(self):
# LEFT OUTER join via backlink
res = self.sess.execute(
select(self.sm.Post, self.sm.User)
.join(self.sm.User.back_to_Post, isouter=True)
.join(self.sm.User._author_Post, isouter=True)
.order_by(self.sm.Post.body)
)
vals = {
(u.name, tuple(p.body for p in u.back_to_Post))
(u.name, tuple(p.body for p in u._author_Post))
for (p, u) in res
}
self.assertEqual(
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_sqla_read_models_04(self):
# use backlink
res = self.sess.query(self.sm.User).all()
vals = {
(u.name, tuple(g.num for g in u.back_to_GameSession))
(u.name, tuple(g.num for g in u._players_GameSession))
for u in res
}
self.assertEqual(
Expand Down Expand Up @@ -248,10 +248,10 @@ def test_sqla_read_models_05(self):
# LEFT OUTER join via backlink
res = self.sess.execute(
select(self.sm.GameSession, self.sm.User)
.join(self.sm.User.back_to_GameSession, isouter=True)
.join(self.sm.User._players_GameSession, isouter=True)
)
vals = {
(u.name, tuple(g.num for g in u.back_to_GameSession))
(u.name, tuple(g.num for g in u._players_GameSession))
for (_, u) in res
}
self.assertEqual(
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_sqla_read_models_06(self):
# use backlink
res = self.sess.query(self.sm.User).order_by(self.sm.User.name).all()
vals = [
(u.name, {g.name for g in u.back_to_UserGroup})
(u.name, {g.name for g in u._users_UserGroup})
for u in res
]
self.assertEqual(
Expand Down Expand Up @@ -328,10 +328,10 @@ def test_sqla_read_models_07(self):
# LEFT OUTER join via backlink
res = self.sess.execute(
select(self.sm.UserGroup, self.sm.User)
.join(self.sm.User.back_to_UserGroup, isouter=True)
.join(self.sm.User._users_UserGroup, isouter=True)
)
vals = {
(u.name, tuple(sorted(g.name for g in u.back_to_UserGroup)))
(u.name, tuple(sorted(g.name for g in u._users_UserGroup)))
for (_, u) in res
}
self.assertEqual(
Expand Down Expand Up @@ -400,16 +400,16 @@ def test_sqla_create_models_02(self):
user = self.sess.query(self.sm.User).filter_by(name=name).one()

self.assertEqual(user.name, name)
self.assertEqual(user.back_to_UserGroup[0].name, 'cyan')
self.assertEqual(user._users_UserGroup[0].name, 'cyan')
self.assertIsInstance(user.id, uuid.UUID)

def test_sqla_create_models_03(self):
user0 = self.sm.User(name='Yvonne')
user1 = self.sm.User(name='Xander')
cyan = self.sm.UserGroup(name='cyan')

user0.back_to_UserGroup.append(cyan)
user1.back_to_UserGroup.append(cyan)
user0._users_UserGroup.append(cyan)
user1._users_UserGroup.append(cyan)

self.sess.add(cyan)
self.sess.flush()
Expand All @@ -418,7 +418,7 @@ def test_sqla_create_models_03(self):
user = self.sess.query(self.sm.User).filter_by(name=name).one()

self.assertEqual(user.name, name)
self.assertEqual(user.back_to_UserGroup[0].name, 'cyan')
self.assertEqual(user._users_UserGroup[0].name, 'cyan')
self.assertIsInstance(user.id, uuid.UUID)

def test_sqla_create_models_04(self):
Expand Down Expand Up @@ -557,13 +557,13 @@ def test_sqla_update_models_02(self):
self.sess.flush()

self.assertEqual(
{g.name for g in user.back_to_UserGroup},
{g.name for g in user._users_UserGroup},
{'red', 'blue'},
)
self.assertEqual(user.name, 'Yvonne')
self.assertIsInstance(user.id, uuid.UUID)

group = [g for g in user.back_to_UserGroup if g.name == 'red'][0]
group = [g for g in user._users_UserGroup if g.name == 'red'][0]
self.assertEqual(
{u.name for u in group.users},
{'Alice', 'Billie', 'Cameron', 'Dana', 'Yvonne'},
Expand All @@ -573,7 +573,7 @@ def test_sqla_update_models_03(self):
user0 = self.sess.query(self.sm.User).filter_by(name='Elsa').one()
user1 = self.sess.query(self.sm.User).filter_by(name='Zoe').one()
# Replace the author or a post
post = user0.back_to_Post[0]
post = user0._author_Post[0]
body = post.body
post.author = user1

Expand Down
11 changes: 4 additions & 7 deletions tests/test_sqla_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,24 +224,21 @@ def test_sqla_bklink_01(self):

# only one link from Bar 123 to foo
self.assertEqual(
[obj.n for obj in foo.follow_foo_back_to_Bar],
[obj.n for obj in foo._foo_Bar],
[123],
)
# only one link from Who 456 to oof
self.assertEqual(
[obj.x for obj in oof.follow_foo_back_to_Who],
[obj.x for obj in oof._foo_Who],
[456],
)

# foo is linked via `many_foo` from both Bar and Who
self.assertEqual(
[obj.n for obj in foo.follow_many_foo_back_to_Bar],
[obj.n for obj in foo._many_foo_Bar],
[123],
)
self.assertEqual(
[
(obj.note, obj.source.x)
for obj in foo.follow_many_foo_back_to_Who
],
[(obj.note, obj.source.x) for obj in foo._many_foo_Who],
[('just one', 456)],
)
Loading

0 comments on commit b65440c

Please sign in to comment.