Skip to content

Commit

Permalink
Merge pull request #297 from wwwjfy/t261
Browse files Browse the repository at this point in the history
fixed #261, map model and db column names
  • Loading branch information
wwwjfy authored Aug 27, 2018
2 parents 3b11fd8 + 31a3878 commit 4588e84
Show file tree
Hide file tree
Showing 14 changed files with 171 additions and 73 deletions.
13 changes: 13 additions & 0 deletions docs/schema.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,19 @@ more declarative. Instead of ``users.c.name``, you can now access the column by
available at ``User.__table__`` and ``Address.__table__``. You can use anything
that works in GINO core here.

.. note::

Column names can be different as a class property and database column.
For example, name can be declared as
``nickname = db.Column('name', db.Unicode(), default='noname')``. In this
example, ``User.nickname`` is used to access the column, while in database,
the column name is ``name``.

What's worth mentioning is where raw SQL statements are used, or
``TableClause`` is involved, like ``User.insert()``, the original name is
required to be used, because in this case, GINO has no knowledge about the
mappings.

.. tip::

``db.Model`` is a dynamically created parent class for your models. It is
Expand Down
60 changes: 38 additions & 22 deletions gino/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.sql import ClauseElement

from . import json_support
from .declarative import Model
from .declarative import Model, InvertDict
from .exceptions import NoSuchRowError
from .loader import AliasLoader, ModelLoader

Expand Down Expand Up @@ -78,7 +78,7 @@ class UpdateRequest:
specific model instance and its database row.
"""
def __init__(self, instance):
def __init__(self, instance: 'CRUDModel'):
self._instance = instance
self._values = {}
self._props = {}
Expand All @@ -88,7 +88,7 @@ def __init__(self, instance):
try:
self._locator = instance.lookup()
except LookupError:
# apply() will fail anyway, but still allow updates()
# apply() will fail anyway, but still allow update()
pass

def _set(self, key, value):
Expand Down Expand Up @@ -124,7 +124,7 @@ async def apply(self, bind=None, timeout=DEFAULT):
json_updates = {}
for prop, value in self._props.items():
value = prop.save(self._instance, value)
updates = json_updates.setdefault(prop.column_name, {})
updates = json_updates.setdefault(prop.prop_name, {})
if self._literal:
updates[prop.name] = value
else:
Expand All @@ -133,26 +133,28 @@ async def apply(self, bind=None, timeout=DEFAULT):
elif not isinstance(value, ClauseElement):
value = sa.cast(value, sa.Unicode)
updates[sa.cast(prop.name, sa.Unicode)] = value
for column_name, updates in json_updates.items():
column = getattr(cls, column_name)
for prop_name, updates in json_updates.items():
prop = getattr(cls, prop_name)
from .dialects.asyncpg import JSONB
if isinstance(column.type, JSONB):
if isinstance(prop.type, JSONB):
if self._literal:
values[column_name] = column.concat(updates)
values[prop_name] = prop.concat(updates)
else:
values[column_name] = column.concat(
values[prop_name] = prop.concat(
sa.func.jsonb_build_object(
*itertools.chain(*updates.items())))
else:
raise TypeError('{} is not supported.'.format(column.type))
raise TypeError('{} is not supported to update json '
'properties in Gino. Please consider using '
'JSONB.'.format(prop.type))

opts = dict(return_model=False)
if timeout is not DEFAULT:
opts['timeout'] = timeout
clause = type(self._instance).update.where(
self._locator,
).values(
**values,
**self._instance._get_sa_values(values),
).returning(
*[getattr(cls, key) for key in values],
).execution_options(**opts)
Expand All @@ -161,7 +163,9 @@ async def apply(self, bind=None, timeout=DEFAULT):
row = await bind.first(clause)
if not row:
raise NoSuchRowError()
self._instance.__values__.update(row)
for k, v in row.items():
self._instance.__values__[
self._instance._column_name_map.invert_get(k)] = v
for prop in self._props:
prop.reload(self._instance)
return self
Expand Down Expand Up @@ -409,6 +413,7 @@ class CRUDModel(Model):
"""

_update_request_cls = UpdateRequest
_column_name_map = InvertDict()

def __init__(self, **values):
super().__init__()
Expand All @@ -421,10 +426,10 @@ def _init_table(cls, sub_cls):
for each_cls in sub_cls.__mro__[::-1]:
for k, v in each_cls.__dict__.items():
if isinstance(v, json_support.JSONProperty):
if not hasattr(sub_cls, v.column_name):
if not hasattr(sub_cls, v.prop_name):
raise AttributeError(
'Requires "{}" JSON[B] column.'.format(
v.column_name))
v.prop_name))
v.name = k
if rv is not None:
rv.__model__ = weakref.ref(sub_cls)
Expand All @@ -440,12 +445,12 @@ async def _create(self, bind=None, timeout=DEFAULT):
cls = type(self)
# noinspection PyUnresolvedReferences,PyProtectedMember
cls._check_abstract()
keys = set(self.__profile__.keys() if self.__profile__ else [])
for key in keys:
profile_keys = set(self.__profile__.keys() if self.__profile__ else [])
for key in profile_keys:
cls.__dict__.get(key).save(self)
# initialize default values
for key, prop in cls.__dict__.items():
if key in keys:
if key in profile_keys:
continue
if isinstance(prop, json_support.JSONProperty):
if prop.default is None or prop.after_get.method is not None:
Expand All @@ -458,15 +463,25 @@ async def _create(self, bind=None, timeout=DEFAULT):
if timeout is not DEFAULT:
opts['timeout'] = timeout
# noinspection PyArgumentList
q = cls.__table__.insert().values(**self.__values__).returning(
*cls).execution_options(**opts)
q = cls.__table__.insert().values(
**self._get_sa_values(self.__values__)
).returning(
*cls
).execution_options(**opts)
if bind is None:
bind = cls.__metadata__.bind
row = await bind.first(q)
self.__values__.update(row)
for k, v in row.items():
self.__values__[self._column_name_map.invert_get(k)] = v
self.__profile__ = None
return self

def _get_sa_values(self, instance_values: dict) -> dict:
values = {}
for k, v in instance_values.items():
values[self._column_name_map[k]] = v
return values

@classmethod
async def get(cls, ident, bind=None, timeout=DEFAULT):
"""
Expand Down Expand Up @@ -592,11 +607,12 @@ def to_dict(self):
"""
cls = type(self)
keys = set(c.name for c in cls)
# noinspection PyTypeChecker
keys = set(cls._column_name_map.invert_get(c.name) for c in cls)
for key, prop in cls.__dict__.items():
if isinstance(prop, json_support.JSONProperty):
keys.add(key)
keys.discard(prop.column_name)
keys.discard(prop.prop_name)
return dict((k, getattr(self, k)) for k in keys)

@classmethod
Expand Down
44 changes: 37 additions & 7 deletions gino/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,50 @@
import sqlalchemy as sa
from sqlalchemy.exc import InvalidRequestError

from .exceptions import GinoException


class ColumnAttribute:
def __init__(self, column):
self.name = column.name
def __init__(self, prop_name, column):
self.prop_name = prop_name
self.column = column

def __get__(self, instance, owner):
if instance is None:
return self.column
else:
return instance.__values__.get(self.name)
return instance.__values__.get(self.prop_name)

def __set__(self, instance, value):
instance.__values__[self.name] = value
instance.__values__[self.prop_name] = value

def __delete__(self, instance):
raise AttributeError('Cannot delete value.')


class InvertDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._inverted_dict = dict()
for k, v in self.items():
if v in self._inverted_dict:
raise GinoException(
'Column name {} already maps to {}'.format(
v, self._inverted_dict[v]))
self._inverted_dict[v] = k

def __setitem__(self, key, value):
if value in self._inverted_dict and self._inverted_dict[value] != key:
raise GinoException(
'Column name {} already maps to {}'.format(
value, self._inverted_dict[value]))
super().__setitem__(key, value)
self._inverted_dict[value] = key

def invert_get(self, key, default=None):
return self._inverted_dict.get(key, default)


class ModelType(type):
def _check_abstract(self):
if self.__table__ is None:
Expand Down Expand Up @@ -119,18 +144,22 @@ def _init_table(cls, sub_cls):
columns = []
inspected_args = []
updates = {}
column_name_map = InvertDict()
for each_cls in sub_cls.__mro__[::-1]:
for k, v in getattr(each_cls, '__namespace__',
each_cls.__dict__).items():
if callable(v) and getattr(v, '__declared_attr__', False):
v = updates[k] = v(sub_cls)
if isinstance(v, sa.Column):
v = v.copy()
v.name = k
if not v.name:
v.name = k
column_name_map[k] = v.name
columns.append(v)
updates[k] = sub_cls.__attr_factory__(v)
updates[k] = sub_cls.__attr_factory__(k, v)
elif isinstance(v, (sa.Index, sa.Constraint)):
inspected_args.append(v)
sub_cls._column_name_map = column_name_map

# handle __table_args__
table_args = updates.get('__table_args__',
Expand Down Expand Up @@ -173,4 +202,5 @@ def inspect_model_type(target):
return sa.inspection.inspect(target.__table__)


__all__ = ['ColumnAttribute', 'Model', 'declarative_base', 'declared_attr']
__all__ = ['ColumnAttribute', 'Model', 'declarative_base', 'declared_attr',
'InvertDict']
2 changes: 1 addition & 1 deletion gino/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from contextvars import ContextVar
else:
# noinspection PyPackageRequirements
from aiocontextvars import ContextVar
from aiocontextvars import ContextVar # pragma: no cover


class _BaseDBAPIConnection:
Expand Down
14 changes: 7 additions & 7 deletions gino/json_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ def call(self, instance, val):


class JSONProperty:
def __init__(self, default=None, column_name='profile'):
def __init__(self, default=None, prop_name='profile'):
self.name = None
self.default = default
self.column_name = column_name
self.prop_name = prop_name
self.expression = Hook(self)
self.after_get = Hook(self)
self.before_set = Hook(self)

def __get__(self, instance, owner):
if instance is None:
exp = self.make_expression(
getattr(owner, self.column_name)[self.name])
getattr(owner, self.prop_name)[self.name])
return self.expression.call(owner, exp)
val = self.get_profile(instance).get(self.name, NONE)
if val is NONE:
Expand All @@ -54,16 +54,16 @@ def get_profile(self, instance):
if instance.__profile__ is None:
props = type(instance).__dict__
instance.__profile__ = {}
for key, value in (getattr(instance, self.column_name, None)
for key, value in (getattr(instance, self.prop_name, None)
or {}).items():
instance.__profile__[key] = props[key].decode(value)
return instance.__profile__

def save(self, instance, value=NONE):
profile = getattr(instance, self.column_name, None)
profile = getattr(instance, self.prop_name, None)
if profile is None:
profile = {}
setattr(instance, self.column_name, profile)
setattr(instance, self.prop_name, profile)
if value is NONE:
value = instance.__profile__[self.name]
if not isinstance(value, sa.sql.ClauseElement):
Expand All @@ -74,7 +74,7 @@ def save(self, instance, value=NONE):
def reload(self, instance):
if instance.__profile__ is None:
return
profile = getattr(instance, self.column_name, None) or {}
profile = getattr(instance, self.prop_name, None) or {}
value = profile.get(self.name, NONE)
if value is NONE:
instance.__profile__.pop(self.name, None)
Expand Down
6 changes: 5 additions & 1 deletion gino/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def _do_load(self, row, *, none_as_none=None):
if none_as_none and all((v is None) for v in values.values()):
return None
rv = self.model()
rv.__values__.update(values)
for c in self.columns:
if c in row:
# noinspection PyProtectedMember
instance_key = self.model._column_name_map.invert_get(c.name)
rv.__values__[instance_key] = row[c]
return rv

def do_load(self, row, context):
Expand Down
6 changes: 3 additions & 3 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ class User(db.Model):
__tablename__ = 'gino_users'

id = db.Column(db.BigInteger(), primary_key=True)
nickname = db.Column(db.Unicode(), default='noname')
profile = db.Column(JSONB(), nullable=False, server_default='{}')
nickname = db.Column('name', db.Unicode(), default='noname')
profile = db.Column('props', JSONB(), nullable=False, server_default='{}')
type = db.Column(
db.Enum(UserType),
nullable=False,
default=UserType.USER,
)
name = db.StringProperty()
realname = db.StringProperty()
age = db.IntegerProperty(default=18)
balance = db.IntegerProperty(default=0)
birthday = db.DateTimeProperty(
Expand Down
5 changes: 2 additions & 3 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class User(db.Model):
__tablename__ = 'gino_users'

id = db.Column(db.BigInteger(), primary_key=True)
nickname = db.Column(db.Unicode(), default='noname')
nickname = db.Column('name', db.Unicode(), default='noname')

routes = web.RouteTableDef()

Expand Down Expand Up @@ -107,8 +107,7 @@ async def _test(test_client):
response = await test_client.get('/users/1?method=' + method)
assert response.status == 404

response = await test_client.post('/users',
data=dict(name='fantix'))
response = await test_client.post('/users', data=dict(name='fantix'))
assert response.status == 200
assert await response.json() == dict(id=1, nickname='fantix')

Expand Down
2 changes: 1 addition & 1 deletion tests/test_bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def test_unbind(asyncpg_pool):

async def test_db_api(bind, random_name):
assert await db.scalar(
User.insert().values(nickname=random_name).returning(
User.insert().values(name=random_name).returning(
User.nickname)) == random_name
assert (await db.first(User.query.where(
User.nickname == random_name))).nickname == random_name
Expand Down
Loading

0 comments on commit 4588e84

Please sign in to comment.