This repository has been archived by the owner on Dec 21, 2022. It is now read-only.

Merge pull request #38 from brandicted/develop
release 0.2.2
chartpath committed May 27, 2015
2 parents df5cf99 + cba3123 commit fee403a
Showing 8 changed files with 424 additions and 166 deletions.
docs/source/changelog.rst (5 additions, 0 deletions)

@@ -1,6 +1,11 @@
 Changelog
 =========
 
+* :release:`0.2.2 <2015-05-27>`
+* :bug:`-` Fixes login issue
+* :bug:`-` Fixes posting to singular resources, e.g. /api/users/<username>/profile
+* :bug:`-` Fixes multiple foreign keys to the same model
+
 * :release:`0.2.1 <2015-05-20>`
 * :bug:`-` Fixed slow queries to backrefs

nefertari_sqla/__init__.py (1 addition, 1 deletion)

@@ -6,7 +6,7 @@

 from .documents import (
     BaseDocument, ESBaseDocument, BaseMixin,
-    get_document_cls)
+    get_document_cls, get_document_classes)
 from .serializers import JSONEncoder, ESJSONSerializer
 from .signals import ESMetaclass
 from .utils import (
nefertari_sqla/documents.py (157 additions, 24 deletions)

@@ -2,10 +2,12 @@
 import logging
 from datetime import datetime
 
-from sqlalchemy.orm import class_mapper, object_session, properties
+from sqlalchemy.orm import (
+    class_mapper, object_session, properties, attributes)
 from sqlalchemy.orm.collections import InstrumentedList
 from sqlalchemy.exc import InvalidRequestError, IntegrityError
 from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
+from sqlalchemy.orm.properties import RelationshipProperty
 from pyramid_sqlalchemy import Session, BaseObject
 
 from nefertari.json_httpexceptions import (
@@ -14,7 +16,9 @@
     process_fields, process_limit, _split, dictset,
     DataProxy)
 from .signals import ESMetaclass
-from .fields import DateTimeField, IntegerField, DictField, ListField
+from .fields import ListField, DictField, DateTimeField, IntegerField
+from . import types
+
 
 log = logging.getLogger(__name__)
 
@@ -26,6 +30,22 @@ def get_document_cls(name):
     raise ValueError('SQLAlchemy model `{}` does not exist'.format(name))
 
 
+def get_document_classes():
+    """ Get all defined, non-abstract document classes.
+
+    A class is assumed to be non-abstract if it has a `__table__` or
+    `__tablename__` attribute defined.
+    """
+    document_classes = {}
+    registry = BaseObject._decl_class_registry
+    for model_name, model_cls in registry.items():
+        tablename = (getattr(model_cls, '__table__', None) is not None or
+                     getattr(model_cls, '__tablename__', None) is not None)
+        if tablename:
+            document_classes[model_name] = model_cls
+    return document_classes
+
+
 def process_lists(_dict):
     for k in _dict:
         new_k, _, _t = k.partition('__')
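
The new get_document_classes() helper complements get_document_cls(name): rather than resolving a single name, it returns every concrete (mapped) model registered with the declarative base. A minimal usage sketch; the loop body is illustrative, not part of the library:

    from nefertari_sqla import get_document_classes

    # Only classes defining __table__ or __tablename__ are returned;
    # abstract base classes are filtered out.
    for name, model_cls in get_document_classes().items():
        print(name, getattr(model_cls, '__tablename__', None))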
@@ -42,6 +62,31 @@ def process_bools(_dict):
     return _dict
 
 
+TYPES_MAP = {
+    types.LimitedString: {'type': 'string'},
+    types.LimitedText: {'type': 'string'},
+    types.LimitedUnicode: {'type': 'string'},
+    types.LimitedUnicodeText: {'type': 'string'},
+    types.ProcessableChoice: {'type': 'string'},
+
+    types.ProcessableBoolean: {'type': 'boolean'},
+    types.ProcessableLargeBinary: {'type': 'object'},
+    types.ProcessableDict: {'type': 'object'},
+
+    types.LimitedNumeric: {'type': 'double'},
+    types.LimitedFloat: {'type': 'double'},
+
+    types.LimitedInteger: {'type': 'long'},
+    types.LimitedBigInteger: {'type': 'long'},
+    types.LimitedSmallInteger: {'type': 'long'},
+    types.ProcessableInterval: {'type': 'long'},
+
+    types.ProcessableDateTime: {'type': 'date', 'format': 'dateOptionalTime'},
+    types.ProcessableDate: {'type': 'date', 'format': 'dateOptionalTime'},
+    types.ProcessableTime: {'type': 'date', 'format': 'HH:mm:ss'},
+}
+
+
 class BaseMixin(object):
     """ Represents mixin class for models.
@@ -62,6 +107,33 @@ class BaseMixin(object):

     _type = property(lambda self: self.__class__.__name__)
 
+    @classmethod
+    def get_es_mapping(cls):
+        """ Generate ES mapping from model schema. """
+        from nefertari.elasticsearch import ES
+        properties = {}
+        mapping = {
+            ES.src2type(cls.__name__): {
+                'properties': properties
+            }
+        }
+        mapper = class_mapper(cls)
+        columns = {c.name: c for c in mapper.columns}
+        # Replace field 'id' with primary key field
+        columns['id'] = columns.get(cls.pk_field())
+
+        for name, column in columns.items():
+            column_type = column.type
+            if isinstance(column_type, types.ProcessableChoiceArray):
+                column_type = column_type.impl.item_type
+            column_type = type(column_type)
+            if column_type not in TYPES_MAP:
+                continue
+            properties[name] = TYPES_MAP[column_type]
+
+        properties['_type'] = {'type': 'string'}
+        return mapping
+
     @classmethod
     def autogenerate_for(cls, model, set_to):
         """ Setup `after_insert` event handler.
@@ -192,7 +264,6 @@ def _pop_iterables(cls, params):
         If ListField uses the `postgresql.ARRAY` type, the value is
         wrapped in a list.
         """
-        from .fields import ListField, DictField
         iterables = {}
         columns = class_mapper(cls).columns
         columns = {c.name: c for c in columns
@@ -377,24 +448,20 @@ def get_or_create(cls, **params):
     def _update(self, params, **kw):
         process_bools(params)
         self.check_fields_allowed(params.keys())
-        fields = {c.name: c for c in class_mapper(self.__class__).columns}
-        iter_fields = set(
-            k for k, v in fields.items()
+        columns = {c.name: c for c in class_mapper(self.__class__).columns}
+        iter_columns = set(
+            k for k, v in columns.items()
             if isinstance(v, (DictField, ListField)))
         pk_field = self.pk_field()
 
         for key, new_value in params.items():
             # Can't change PK field
             if key == pk_field:
                 continue
-            if key in iter_fields:
+            if key in iter_columns:
                 self.update_iterables(new_value, key, unique=True, save=False)
             else:
                 setattr(self, key, new_value)
-
-        session = object_session(self)
-        session.add(self)
-        session.flush()
         return self
 
     @classmethod
@@ -431,6 +498,21 @@ def get_by_ids(cls, ids, **params):
         cls_id = getattr(cls, cls.pk_field())
         return query_set.from_self().filter(cls_id.in_(ids)).limit(len(ids))
 
+    @classmethod
+    def get_null_values(cls):
+        """ Get null values of :cls: fields. """
+        null_values = {}
+        mapper = class_mapper(cls)
+        columns = {c.name: c for c in mapper.columns}
+        columns.update({r.key: r for r in mapper.relationships})
+        for name, col in columns.items():
+            if isinstance(col, RelationshipProperty) and col.uselist:
+                value = []
+            else:
+                value = None
+            null_values[name] = value
+        return null_values
+
     def to_dict(self, **kwargs):
         native_fields = self.__class__.native_fields()
         _data = {}
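
get_null_values() reports each field's "empty" value: [] for a to-many relationship (uselist=True), None for everything else. For a hypothetical User model with a one-to-many stories backref, the result would look like:

    # User.get_null_values() -> (field names are illustrative)
    {'id': None, 'username': None, 'stories': []}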
@@ -452,9 +534,9 @@ def to_dict(self, **kwargs):
     def update_iterables(self, params, attr, unique=False,
                          value_type=None, save=True):
         mapper = class_mapper(self.__class__)
-        fields = {c.name: c for c in mapper.columns}
-        is_dict = isinstance(fields.get(attr), DictField)
-        is_list = isinstance(fields.get(attr), ListField)
+        columns = {c.name: c for c in mapper.columns}
+        is_dict = isinstance(columns.get(attr), DictField)
+        is_list = isinstance(columns.get(attr), ListField)
 
         def split_keys(keys):
             neg_keys, pos_keys = [], []
@@ -468,29 +550,35 @@ def split_keys(keys):
                     pos_keys.append(key.strip())
             return pos_keys, neg_keys
 
-        def update_dict():
+        def update_dict(update_params):
             final_value = getattr(self, attr, {}) or {}
             final_value = final_value.copy()
-            positive, negative = split_keys(params.keys())
+            if update_params is None:
+                update_params = {
+                    '-' + key: val for key, val in final_value.items()}
+            positive, negative = split_keys(update_params.keys())
 
             # Pop negative keys
             for key in negative:
                 final_value.pop(key, None)
 
             # Set positive keys
             for key in positive:
-                final_value[unicode(key)] = params[key]
+                final_value[unicode(key)] = update_params[key]
 
             setattr(self, attr, final_value)
             if save:
                 session = object_session(self)
                 session.add(self)
                 session.flush()
 
-        def update_list():
+        def update_list(update_params):
             final_value = getattr(self, attr, []) or []
             final_value = copy.deepcopy(final_value)
-            keys = params.keys() if isinstance(params, dict) else params
+            if update_params is None:
+                update_params = ['-' + val for val in final_value]
+            keys = (update_params.keys() if isinstance(update_params, dict)
+                    else update_params)
             positive, negative = split_keys(keys)
 
             if not (positive + negative):
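
Behavioral note on the change above: update_dict()/update_list() now take the params explicitly, and a None value is expanded into an all-negative key set, i.e. "remove everything". Sketched effect on a hypothetical instance with a DictField settings and a ListField tags:

    # Given obj.settings == {'a': 1, 'b': 2} and obj.tags == ['x', 'y']:
    obj.update_iterables(None, 'settings')  # None -> {'-a': 1, '-b': 2}
    obj.update_iterables(None, 'tags')      # None -> ['-x', '-y']
    # Afterwards obj.settings == {} and obj.tags == []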
@@ -511,9 +599,9 @@ def update_list():
                 session.flush()
 
         if is_dict:
-            update_dict()
+            update_dict(params)
         elif is_list:
-            update_list()
+            update_list(params)
 
     def get_reference_documents(self):
         # TODO: Make lazy load of documents
@@ -531,6 +619,21 @@ def get_reference_documents(self):
             session.refresh(value)
             yield (value.__class__, [value.to_dict()])
 
+    def _is_modified(self):
+        """ Determine if the instance is modified.
+
+        For an instance to be considered 'modified', it must:
+            * have its state marked as modified
+            * have its state marked as persistent
+            * have at least one modified field with a new value
+        """
+        state = attributes.instance_state(self)
+        if state.persistent and state.modified:
+            for field in state.committed_state.keys():
+                history = state.get_history(field, self)
+                if history.added or history.deleted:
+                    return True
+
 
 class BaseDocument(BaseObject, BaseMixin):
     """ Base class for SQLA models.
@@ -544,7 +647,7 @@ class BaseDocument(BaseObject, BaseMixin):
     _version = IntegerField(default=0)
 
     def _bump_version(self):
-        if getattr(self, self.pk_field(), None):
+        if self._is_modified():
             self.updated_at = datetime.utcnow()
             self._version = (self._version or 0) + 1
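
Previously the version bumped whenever the primary key was already set; it now bumps only when _is_modified() detects real changes. The practical effect, sketched on a hypothetical persistent story instance:

    story.name = story.name   # assignment recorded, but no added/deleted history
    story.save()              # _is_modified() is falsy: _version unchanged
    story.name = 'New title'  # an actual change
    story.save()              # _version += 1 and updated_at refreshed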

@@ -553,6 +656,7 @@ def save(self, *arg, **kw):
         self._bump_version()
         session = session or Session()
         try:
+            self.clean()
             session.add(self)
             session.flush()
             session.expire(self)
@@ -567,9 +671,14 @@ def save(self, *arg, **kw):
                 extra={'data': e})
 
     def update(self, params):
-        self._bump_version()
         try:
-            return self._update(params)
+            self._update(params)
+            self._bump_version()
+            self.clean()
+            session = object_session(self)
+            session.add(self)
+            session.flush()
+            return self
         except (IntegrityError,) as e:
             if 'duplicate' not in e.message:
                 raise # other error, not duplicate
@@ -579,6 +688,30 @@ def update(self, params):
                 self.__class__.__name__),
             extra={'data': e})
 
+    def clean(self, force_all=False):
+        """ Apply field processors to all changed fields and perform custom
+        field-value cleaning before running DB validation.
+
+        Note that at this stage, field values are in the exact same state
+        you posted/set them in. E.g. if you set time_field='11/22/2000',
+        self.time_field will equal '11/22/2000' here.
+        """
+        columns = {c.key: c for c in class_mapper(self.__class__).columns}
+        state = attributes.instance_state(self)
+
+        if state.persistent and not force_all:
+            changed_columns = state.committed_state.keys()
+        else:  # New object
+            changed_columns = columns.keys()
+
+        for name in changed_columns:
+            column = columns.get(name)
+            if column is not None and hasattr(column, 'apply_processors'):
+                new_value = getattr(self, name)
+                processed_value = column.apply_processors(
+                    instance=self, new_value=new_value)
+                setattr(self, name, processed_value)
+
 
 class ESBaseDocument(BaseDocument):
     """ Base class for SQLA models that use Elasticsearch.
(Diffs for the remaining 5 changed files are not shown here.)
