diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index b0266ca..be8a4a9 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,11 @@ Changelog ========= +* :release:`0.2.2 <2015-05-27>` +* :bug:`-` fixes login issue +* :bug:`-` fixes posting to singular resources e.g. /api/users//profile +* :bug:`-` fixes multiple foreign keys to same model + * :release:`0.2.1 <2015-05-20>` * :bug:`-` Fixed slow queries to backrefs diff --git a/nefertari_sqla/__init__.py b/nefertari_sqla/__init__.py index 9adfd26..305fe7e 100644 --- a/nefertari_sqla/__init__.py +++ b/nefertari_sqla/__init__.py @@ -6,7 +6,7 @@ from .documents import ( BaseDocument, ESBaseDocument, BaseMixin, - get_document_cls) + get_document_cls, get_document_classes) from .serializers import JSONEncoder, ESJSONSerializer from .signals import ESMetaclass from .utils import ( diff --git a/nefertari_sqla/documents.py b/nefertari_sqla/documents.py index 571a495..53254ce 100644 --- a/nefertari_sqla/documents.py +++ b/nefertari_sqla/documents.py @@ -2,10 +2,12 @@ import logging from datetime import datetime -from sqlalchemy.orm import class_mapper, object_session, properties +from sqlalchemy.orm import ( + class_mapper, object_session, properties, attributes) from sqlalchemy.orm.collections import InstrumentedList from sqlalchemy.exc import InvalidRequestError, IntegrityError from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound +from sqlalchemy.orm.properties import RelationshipProperty from pyramid_sqlalchemy import Session, BaseObject from nefertari.json_httpexceptions import ( @@ -14,7 +16,9 @@ process_fields, process_limit, _split, dictset, DataProxy) from .signals import ESMetaclass -from .fields import DateTimeField, IntegerField, DictField, ListField +from .fields import ListField, DictField, DateTimeField, IntegerField +from . 
import types + log = logging.getLogger(__name__) @@ -26,6 +30,22 @@ def get_document_cls(name): raise ValueError('SQLAlchemy model `{}` does not exist'.format(name)) +def get_document_classes(): + """ Get all defined not abstract document classes + + Class is assumed to be non-abstract if it has `__table__` or + `__tablename__` attributes defined. + """ + document_classes = {} + registry = BaseObject._decl_class_registry + for model_name, model_cls in registry.items(): + tablename = (getattr(model_cls, '__table__', None) is not None or + getattr(model_cls, '__tablename__', None) is not None) + if tablename: + document_classes[model_name] = model_cls + return document_classes + + def process_lists(_dict): for k in _dict: new_k, _, _t = k.partition('__') @@ -42,6 +62,31 @@ def process_bools(_dict): return _dict +TYPES_MAP = { + types.LimitedString: {'type': 'string'}, + types.LimitedText: {'type': 'string'}, + types.LimitedUnicode: {'type': 'string'}, + types.LimitedUnicodeText: {'type': 'string'}, + types.ProcessableChoice: {'type': 'string'}, + + types.ProcessableBoolean: {'type': 'boolean'}, + types.ProcessableLargeBinary: {'type': 'object'}, + types.ProcessableDict: {'type': 'object'}, + + types.LimitedNumeric: {'type': 'double'}, + types.LimitedFloat: {'type': 'double'}, + + types.LimitedInteger: {'type': 'long'}, + types.LimitedBigInteger: {'type': 'long'}, + types.LimitedSmallInteger: {'type': 'long'}, + types.ProcessableInterval: {'type': 'long'}, + + types.ProcessableDateTime: {'type': 'date', 'format': 'dateOptionalTime'}, + types.ProcessableDate: {'type': 'date', 'format': 'dateOptionalTime'}, + types.ProcessableTime: {'type': 'date', 'format': 'HH:mm:ss'}, +} + + class BaseMixin(object): """ Represents mixin class for models. @@ -62,6 +107,33 @@ class BaseMixin(object): _type = property(lambda self: self.__class__.__name__) + @classmethod + def get_es_mapping(cls): + """ Generate ES mapping from model schema. 
""" + from nefertari.elasticsearch import ES + properties = {} + mapping = { + ES.src2type(cls.__name__): { + 'properties': properties + } + } + mapper = class_mapper(cls) + columns = {c.name: c for c in mapper.columns} + # Replace field 'id' with primary key field + columns['id'] = columns.get(cls.pk_field()) + + for name, column in columns.items(): + column_type = column.type + if isinstance(column_type, types.ProcessableChoiceArray): + column_type = column_type.impl.item_type + column_type = type(column_type) + if column_type not in TYPES_MAP: + continue + properties[name] = TYPES_MAP[column_type] + + properties['_type'] = {'type': 'string'} + return mapping + @classmethod def autogenerate_for(cls, model, set_to): """ Setup `after_insert` event handler. @@ -192,7 +264,6 @@ def _pop_iterables(cls, params): If ListField uses the `postgresql.ARRAY` type, the value is wrapped in a list. """ - from .fields import ListField, DictField iterables = {} columns = class_mapper(cls).columns columns = {c.name: c for c in columns @@ -377,9 +448,9 @@ def get_or_create(cls, **params): def _update(self, params, **kw): process_bools(params) self.check_fields_allowed(params.keys()) - fields = {c.name: c for c in class_mapper(self.__class__).columns} - iter_fields = set( - k for k, v in fields.items() + columns = {c.name: c for c in class_mapper(self.__class__).columns} + iter_columns = set( + k for k, v in columns.items() if isinstance(v, (DictField, ListField))) pk_field = self.pk_field() @@ -387,14 +458,10 @@ def _update(self, params, **kw): # Can't change PK field if key == pk_field: continue - if key in iter_fields: + if key in iter_columns: self.update_iterables(new_value, key, unique=True, save=False) else: setattr(self, key, new_value) - - session = object_session(self) - session.add(self) - session.flush() return self @classmethod @@ -431,6 +498,21 @@ def get_by_ids(cls, ids, **params): cls_id = getattr(cls, cls.pk_field()) return 
query_set.from_self().filter(cls_id.in_(ids)).limit(len(ids)) + @classmethod + def get_null_values(cls): + """ Get null values of :cls: fields. """ + null_values = {} + mapper = class_mapper(cls) + columns = {c.name: c for c in mapper.columns} + columns.update({r.key: r for r in mapper.relationships}) + for name, col in columns.items(): + if isinstance(col, RelationshipProperty) and col.uselist: + value = [] + else: + value = None + null_values[name] = value + return null_values + def to_dict(self, **kwargs): native_fields = self.__class__.native_fields() _data = {} @@ -452,9 +534,9 @@ def to_dict(self, **kwargs): def update_iterables(self, params, attr, unique=False, value_type=None, save=True): mapper = class_mapper(self.__class__) - fields = {c.name: c for c in mapper.columns} - is_dict = isinstance(fields.get(attr), DictField) - is_list = isinstance(fields.get(attr), ListField) + columns = {c.name: c for c in mapper.columns} + is_dict = isinstance(columns.get(attr), DictField) + is_list = isinstance(columns.get(attr), ListField) def split_keys(keys): neg_keys, pos_keys = [], [] @@ -468,10 +550,13 @@ def split_keys(keys): pos_keys.append(key.strip()) return pos_keys, neg_keys - def update_dict(): + def update_dict(update_params): final_value = getattr(self, attr, {}) or {} final_value = final_value.copy() - positive, negative = split_keys(params.keys()) + if update_params is None: + update_params = { + '-' + key: val for key, val in final_value.items()} + positive, negative = split_keys(update_params.keys()) # Pop negative keys for key in negative: @@ -479,7 +564,7 @@ def update_dict(): # Set positive keys for key in positive: - final_value[unicode(key)] = params[key] + final_value[unicode(key)] = update_params[key] setattr(self, attr, final_value) if save: @@ -487,10 +572,13 @@ def update_dict(): session.add(self) session.flush() - def update_list(): + def update_list(update_params): final_value = getattr(self, attr, []) or [] final_value = 
copy.deepcopy(final_value) - keys = params.keys() if isinstance(params, dict) else params + if update_params is None: + update_params = ['-' + val for val in final_value] + keys = (update_params.keys() if isinstance(update_params, dict) + else update_params) positive, negative = split_keys(keys) if not (positive + negative): @@ -511,9 +599,9 @@ def update_list(): session.flush() if is_dict: - update_dict() + update_dict(params) elif is_list: - update_list() + update_list(params) def get_reference_documents(self): # TODO: Make lazy load of documents @@ -531,6 +619,21 @@ def get_reference_documents(self): session.refresh(value) yield (value.__class__, [value.to_dict()]) + def _is_modified(self): + """ Determine if instance is modified. + + For instance to be marked as 'modified', it should: + * Have state marked as modified + * Have state marked as persistent + * Any of modified fields have new value + """ + state = attributes.instance_state(self) + if state.persistent and state.modified: + for field in state.committed_state.keys(): + history = state.get_history(field, self) + if history.added or history.deleted: + return True + class BaseDocument(BaseObject, BaseMixin): """ Base class for SQLA models. 
@@ -544,7 +647,7 @@ class BaseDocument(BaseObject, BaseMixin): _version = IntegerField(default=0) def _bump_version(self): - if getattr(self, self.pk_field(), None): + if self._is_modified(): self.updated_at = datetime.utcnow() self._version = (self._version or 0) + 1 @@ -553,6 +656,7 @@ def save(self, *arg, **kw): self._bump_version() session = session or Session() try: + self.clean() session.add(self) session.flush() session.expire(self) @@ -567,9 +671,14 @@ def save(self, *arg, **kw): extra={'data': e}) def update(self, params): - self._bump_version() try: - return self._update(params) + self._update(params) + self._bump_version() + self.clean() + session = object_session(self) + session.add(self) + session.flush() + return self except (IntegrityError,) as e: if 'duplicate' not in e.message: raise # other error, not duplicate @@ -579,6 +688,30 @@ def update(self, params): self.__class__.__name__), extra={'data': e}) + def clean(self, force_all=False): + """ Apply field processors to all changed fields And perform custom + field values cleaning before running DB validation. + + Note that at this stage, field values are in the exact same state + you posted/set them. E.g. if you set time_field='11/22/2000', + self.time_field will be equal to '11/22/2000' here. + """ + columns = {c.key: c for c in class_mapper(self.__class__).columns} + state = attributes.instance_state(self) + + if state.persistent and not force_all: + changed_columns = state.committed_state.keys() + else: # New object + changed_columns = columns.keys() + + for name in changed_columns: + column = columns.get(name) + if column is not None and hasattr(column, 'apply_processors'): + new_value = getattr(self, name) + processed_value = column.apply_processors( + instance=self, new_value=new_value) + setattr(self, name, processed_value) + class ESBaseDocument(BaseDocument): """ Base class for SQLA models that use Elasticsearch. 
diff --git a/nefertari_sqla/fields.py b/nefertari_sqla/fields.py index 0dfcd25..5a6342d 100644 --- a/nefertari_sqla/fields.py +++ b/nefertari_sqla/fields.py @@ -25,6 +25,20 @@ ) +class ProcessableMixin(object): + """ Mixin that allows running callables on a value that + is being set on a field. + """ + def __init__(self, *args, **kwargs): + self.processors = kwargs.pop('processors', ()) + super(ProcessableMixin, self).__init__(*args, **kwargs) + + def apply_processors(self, instance, new_value): + for proc in self.processors: + new_value = proc(instance=instance, new_value=new_value) + return new_value + + class BaseField(Column): """ Base plain column that otherwise would be created as sqlalchemy.Column(sqlalchemy.Type()) @@ -61,7 +75,17 @@ def __init__(self, *args, **kwargs): # Column init when defining a schema else: col_kw['type_'] = self._sqla_type(*type_args, **type_kw) - return super(BaseField, self).__init__(**col_kw) + super(BaseField, self).__init__(**col_kw) + + def __setattr__(self, key, value): + """ Store column name on 'self.type' + + This allows error messages in custom types' validation be more + explicit. + """ + if value is not None and key == 'name': + self.type._column_name = value + return super(BaseField, self).__setattr__(key, value) def process_type_args(self, kwargs): """ Process arguments of a sqla Type. 
@@ -112,15 +136,14 @@ def process_column_args(self, kwargs): @property def _constructor(self): return self.__class__ - -class BigIntegerField(BaseField): +class BigIntegerField(ProcessableMixin, BaseField): _sqla_type = LimitedBigInteger - _type_unchanged_kwargs = ('min_value', 'max_value', 'processors') + _type_unchanged_kwargs = ('min_value', 'max_value') -class BooleanField(BaseField): +class BooleanField(ProcessableMixin, BaseField): _sqla_type = ProcessableBoolean - _type_unchanged_kwargs = ('create_constraint', 'processors') + _type_unchanged_kwargs = ('create_constraint',) def process_type_args(self, kwargs): """ @@ -135,33 +158,33 @@ def process_type_args(self, kwargs): return type_args, type_kw, cleaned_kw -class DateField(BaseField): +class DateField(ProcessableMixin, BaseField): _sqla_type = ProcessableDate - _type_unchanged_kwargs = ('processors',) + _type_unchanged_kwargs = () -class DateTimeField(BaseField): +class DateTimeField(ProcessableMixin, BaseField): _sqla_type = ProcessableDateTime - _type_unchanged_kwargs = ('timezone', 'processors') + _type_unchanged_kwargs = ('timezone',) -class ChoiceField(BaseField): +class ChoiceField(ProcessableMixin, BaseField): _sqla_type = ProcessableChoice _type_unchanged_kwargs = ( 'collation', 'convert_unicode', 'unicode_error', - '_warn_on_bytestring', 'choices', 'processors') + '_warn_on_bytestring', 'choices') -class FloatField(BaseField): +class FloatField(ProcessableMixin, BaseField): _sqla_type = LimitedFloat _type_unchanged_kwargs = ( 'precision', 'asdecimal', 'decimal_return_scale', - 'min_value', 'max_value', 'processors') + 'min_value', 'max_value') -class IntegerField(BaseField): +class IntegerField(ProcessableMixin, BaseField): _sqla_type = LimitedInteger - _type_unchanged_kwargs = ('min_value', 'max_value', 'processors') + _type_unchanged_kwargs = ('min_value', 'max_value') class IdField(IntegerField): @@ -171,46 +194,44 @@ class IdField(IntegerField): pass -class IntervalField(BaseField): +class 
IntervalField(ProcessableMixin, BaseField): _sqla_type = ProcessableInterval _type_unchanged_kwargs = ( - 'native', 'second_precision', 'day_precision', 'processors') + 'native', 'second_precision', 'day_precision') -class BinaryField(BaseField): +class BinaryField(ProcessableMixin, BaseField): _sqla_type = ProcessableLargeBinary - _type_unchanged_kwargs = ('length', 'processors') + _type_unchanged_kwargs = ('length',) # Since SQLAlchemy 1.0.0 # class MatchField(BooleanField): # _sqla_type = MatchType -class DecimalField(BaseField): +class DecimalField(ProcessableMixin, BaseField): _sqla_type = LimitedNumeric _type_unchanged_kwargs = ( 'precision', 'scale', 'decimal_return_scale', 'asdecimal', - 'min_value', 'max_value', 'processors') + 'min_value', 'max_value') -class PickleField(BaseField): +class PickleField(ProcessableMixin, BaseField): _sqla_type = ProcessablePickleType _type_unchanged_kwargs = ( - 'protocol', 'pickler', 'comparator', - 'processors') + 'protocol', 'pickler', 'comparator') -class SmallIntegerField(BaseField): +class SmallIntegerField(ProcessableMixin, BaseField): _sqla_type = LimitedSmallInteger - _type_unchanged_kwargs = ('min_value', 'max_value', 'processors') + _type_unchanged_kwargs = ('min_value', 'max_value') -class StringField(BaseField): +class StringField(ProcessableMixin, BaseField): _sqla_type = LimitedString _type_unchanged_kwargs = ( 'collation', 'convert_unicode', 'unicode_error', - '_warn_on_bytestring', 'min_length', 'max_length', - 'processors') + '_warn_on_bytestring', 'min_length', 'max_length') def process_type_args(self, kwargs): """ @@ -454,13 +475,17 @@ def Relationship(**kwargs): simple many-to-one references. 
""" backref_pre = 'backref_' - kwargs['doc'] = kwargs.pop('help_text', None) - kwargs[backref_pre + 'doc'] = kwargs.pop( - backref_pre + 'help_text', None) + if 'help_text' in kwargs: + kwargs['doc'] = kwargs.pop('help_text', None) + if (backref_pre + 'help_text') in kwargs: + kwargs[backref_pre + 'doc'] = kwargs.pop( + backref_pre + 'help_text', None) + kwargs = {k: v for k, v in kwargs.items() if k in relationship_kwargs or k[len(backref_pre):] in relationship_kwargs} rel_kw, backref_kw = {}, {} + for key, val in kwargs.items(): if key.startswith(backref_pre): key = key[len(backref_pre):] diff --git a/nefertari_sqla/tests/test_documents.py b/nefertari_sqla/tests/test_documents.py index 4f77973..32f75eb 100644 --- a/nefertari_sqla/tests/test_documents.py +++ b/nefertari_sqla/tests/test_documents.py @@ -21,6 +21,21 @@ def test_get_document_cls(self, mock_obj): doc_cls = docs.get_document_cls('foo') assert doc_cls == 'bar' + @patch.object(docs, 'BaseObject') + def test_get_document_classes(self, mock_obj): + foo_mock = Mock(__table__='foo') + baz_mock = Mock(__tablename__='baz') + mock_obj._decl_class_registry = { + 'Foo': foo_mock, + 'Bar': Mock(__table__=None), + 'Baz': baz_mock, + } + document_classes = docs.get_document_classes() + assert document_classes == { + 'Foo': foo_mock, + 'Baz': baz_mock, + } + @patch.object(docs, 'BaseObject') def test_get_document_cls_key_error(self, mock_obj): mock_obj._decl_class_registry = {} @@ -56,6 +71,32 @@ def test_process_bools(self): class TestBaseMixin(object): + def test_get_es_mapping(self, memory_db): + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + my_id = fields.IdField() + name = fields.StringField(primary_key=True) + groups = fields.ListField( + item_type=fields.StringField, + choices=['admin', 'user']) + memory_db() + + mapping = MyModel.get_es_mapping() + assert mapping == { + 'mymodel': { + 'properties': { + '_type': {'type': 'string'}, + '_version': {'type': 'long'}, + 'groups': {'type': 
'string'}, + 'id': {'type': 'string'}, + 'my_id': {'type': 'long'}, + 'name': {'type': 'string'}, + 'updated_at': {'format': 'dateOptionalTime', + 'type': 'date'} + } + } + } + def test_pk_field(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' @@ -302,8 +343,7 @@ def test_get_or_create_existing_created( assert one.id == 7 assert one.name == 'q' - @patch.object(docs, 'object_session') - def test_underscore_update(self, obj_session, memory_db): + def test_underscore_update(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' id = fields.IdField(primary_key=True) @@ -314,9 +354,6 @@ class MyModel(docs.BaseDocument): myobj = MyModel(id=4, name='foo') newobj = myobj._update( {'id': 5, 'name': 'bar', 'settings': {'sett1': 'val1'}}) - obj_session.assert_called_once_with(myobj) - obj_session().add.assert_called_once_with(myobj) - obj_session().flush.assert_called_once_with() assert newobj.id == 4 assert newobj.name == 'bar' assert newobj.settings == {'sett1': 'val1'} @@ -352,7 +389,7 @@ def test_repr(self): def test_get_by_ids(self, mock_coll, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' - name = fields.IdField(primary_key=True) + name = fields.StringField(primary_key=True) memory_db() MyModel.name = Mock() MyModel.get_by_ids([1, 2, 3], foo='bar') @@ -361,6 +398,35 @@ class MyModel(docs.BaseDocument): assert mock_coll().from_self().filter.call_count == 1 mock_coll().from_self().filter().limit.assert_called_once_with(3) + def test_get_null_values(self, memory_db): + class MyModel1(docs.BaseDocument): + __tablename__ = 'mymodel1' + name = fields.StringField(primary_key=True) + fk_field = fields.ForeignKeyField( + ref_document='MyModel2', ref_column='mymodel2.name', + ref_column_type=fields.StringField) + + class MyModel2(docs.BaseDocument): + __tablename__ = 'mymodel2' + name = fields.StringField(primary_key=True) + models1 = fields.Relationship( + document='MyModel1', backref_name='model2') + + 
assert MyModel1.get_null_values() == { + '_version': None, + 'fk_field': None, + 'name': None, + 'model2': None, + 'updated_at': None, + } + + assert MyModel2.get_null_values() == { + '_version': None, + 'models1': [], + 'name': None, + 'updated_at': None, + } + def test_to_dict(self, memory_db): class MyModel(docs.BaseDocument): __tablename__ = 'mymodel' @@ -483,6 +549,30 @@ class Parent(docs.BaseDocument): result = [v for v in parent.get_reference_documents()] assert len(result) == 0 + def test_is_modified_id_not_persistent(self, memory_db, simple_model): + memory_db() + obj = simple_model() + assert not obj._is_modified() + + def test_is_modified_no_modified_fields(self, memory_db, simple_model): + memory_db() + obj = simple_model(id=1).save() + assert not obj._is_modified() + + def test_is_modified_same_value_set(self, memory_db, simple_model): + memory_db() + obj = simple_model(id=1, name='foo').save() + obj = simple_model.get(id=1) + obj.name = 'foo' + assert not obj._is_modified() + + def test_is_modified(self, memory_db, simple_model): + memory_db() + obj = simple_model(id=1, name='foo').save() + obj = simple_model.get(id=1) + obj.name = 'bar' + assert obj._is_modified() + class TestBaseDocument(object): @@ -494,10 +584,9 @@ def test_bump_version(self, simple_model, memory_db): assert myobj._version is None assert myobj.updated_at is None myobj._bump_version() - assert myobj._version is None - assert myobj.updated_at is None - myobj.id = 1 + myobj.save() + myobj.name = 'foo' myobj._bump_version() assert myobj._version == 1 assert isinstance(myobj.updated_at, datetime) @@ -509,7 +598,7 @@ def test_save(self, obj_session, simple_model, memory_db): myobj = simple_model(id=4) newobj = myobj.save() assert newobj == myobj - assert myobj._version == 1 + assert myobj._version is None obj_session.assert_called_once_with(myobj) obj_session().add.assert_called_once_with(myobj) obj_session().flush.assert_called_once_with() @@ -526,13 +615,17 @@ def 
test_save_error(self, obj_session, simple_model, memory_db): simple_model(id=4).save() assert 'There was a conflict' in str(ex.value) + @patch.object(docs, 'object_session') @patch.object(docs.BaseMixin, '_update') - def test_update(self, mock_upd, simple_model, memory_db): + def test_update(self, mock_upd, mock_sess, simple_model, memory_db): memory_db() myobj = simple_model(id=4) myobj.update({'name': 'q'}) mock_upd.assert_called_once_with({'name': 'q'}) + mock_sess.assert_called_once_with(myobj) + mock_sess().add.assert_called_once_with(myobj) + mock_sess().flush.assert_called_once_with() @patch.object(docs.BaseMixin, '_update') def test_update_error(self, mock_upd, simple_model, memory_db): @@ -546,6 +639,41 @@ def test_update_error(self, mock_upd, simple_model, memory_db): simple_model(id=4).update({'name': 'q'}) assert 'There was a conflict' in str(ex.value) + def test_clean_new_object(self, memory_db): + processor = lambda instance, new_value: 'foobar' + + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField(processors=[processor]) + email = fields.StringField(processors=[processor]) + memory_db() + + obj = MyModel(name='myname') + obj.clean() + assert obj.name == 'foobar' + assert obj.email == 'foobar' + + def test_clean_existing_object(self, memory_db): + processor = lambda instance, new_value: new_value + '-' + + class MyModel(docs.BaseDocument): + __tablename__ = 'mymodel' + id = fields.IdField(primary_key=True) + name = fields.StringField(processors=[processor]) + email = fields.StringField(processors=[processor]) + memory_db() + + obj = MyModel(id=1, name='myname', email='FOO').save() + assert obj.name == 'myname-' + assert obj.email == 'FOO-' + + obj = MyModel.get(id=1) + obj.name = 'supername' + obj.clean() + assert obj.name == 'supername-' + assert obj.email == 'FOO-' + class TestGetCollection(object): diff --git a/nefertari_sqla/tests/test_types.py 
b/nefertari_sqla/tests/test_types.py index eb79310..95a7fd5 100644 --- a/nefertari_sqla/tests/test_types.py +++ b/nefertari_sqla/tests/test_types.py @@ -14,31 +14,10 @@ def __init__(self, *args, **kwargs): pass -class TestProcessableMixin(object): - - class Processable(types.ProcessableMixin, DemoClass): - pass - - def test_process_bind_param(self): - processors = [ - lambda v: v.lower(), - lambda v: v.strip(), - lambda v: 'Processed ' + v, - ] - mixin = self.Processable(processors=processors) - value = mixin.process_bind_param(' WeIrd ValUE ', None) - assert value == 'Processed weird value' - - def test_process_bind_param_no_processors(self): - mixin = self.Processable() - value = mixin.process_bind_param(' WeIrd ValUE ', None) - assert value == ' WeIrd ValUE ' - - class TestLengthLimitedStringMixin(object): class Limited(types.LengthLimitedStringMixin, DemoClass): - pass + _column_name = 'foo' def test_none_value(self): mixin = self.Limited(min_length=5) @@ -51,7 +30,8 @@ def test_min_length(self): mixin = self.Limited(min_length=5) with pytest.raises(ValueError) as ex: mixin.process_bind_param('q', None) - assert str(ex.value) == 'Value length must be more than 5' + assert str(ex.value) == ( + 'Field `foo`: Value length must be more than 5') try: mixin.process_bind_param('asdasdasd', None) except ValueError: @@ -61,7 +41,8 @@ def test_max_length(self): mixin = self.Limited(max_length=5) with pytest.raises(ValueError) as ex: mixin.process_bind_param('asdasdasdasdasd', None) - assert str(ex.value) == 'Value length must be less than 5' + assert str(ex.value) == ( + 'Field `foo`: Value length must be less than 5') try: mixin.process_bind_param('q', None) except ValueError: @@ -71,10 +52,12 @@ def test_min_and_max_length(self): mixin = self.Limited(max_length=5, min_length=2) with pytest.raises(ValueError) as ex: mixin.process_bind_param('a', None) - assert str(ex.value) == 'Value length must be more than 2' + assert str(ex.value) == ( + 'Field `foo`: Value length 
must be more than 2') with pytest.raises(ValueError) as ex: mixin.process_bind_param('a12313123123', None) - assert str(ex.value) == 'Value length must be less than 5' + assert str(ex.value) == ( + 'Field `foo`: Value length must be less than 5') try: mixin.process_bind_param('12q', None) except ValueError: @@ -84,7 +67,7 @@ def test_min_and_max_length(self): class TestSizeLimitedNumberMixin(object): class Limited(types.SizeLimitedNumberMixin, DemoClass): - pass + _column_name = 'foo' def test_none_value(self): mixin = self.Limited(min_value=5) @@ -97,7 +80,8 @@ def test_min_value(self): mixin = self.Limited(min_value=5) with pytest.raises(ValueError) as ex: mixin.process_bind_param(1, None) - assert str(ex.value) == 'Value must be bigger than 5' + assert str(ex.value) == ( + 'Field `foo`: Value must be bigger than 5') try: mixin.process_bind_param(10, None) except ValueError: @@ -107,7 +91,8 @@ def test_max_value(self): mixin = self.Limited(max_value=5) with pytest.raises(ValueError) as ex: mixin.process_bind_param(10, None) - assert str(ex.value) == 'Value must be less than 5' + assert str(ex.value) == ( + 'Field `foo`: Value must be less than 5') try: mixin.process_bind_param(1, None) except ValueError: @@ -117,10 +102,12 @@ def test_min_and_max_value(self): mixin = self.Limited(max_value=5, min_value=2) with pytest.raises(ValueError) as ex: mixin.process_bind_param(1, None) - assert str(ex.value) == 'Value must be bigger than 2' + assert str(ex.value) == ( + 'Field `foo`: Value must be bigger than 2') with pytest.raises(ValueError) as ex: mixin.process_bind_param(10, None) - assert str(ex.value) == 'Value must be less than 5' + assert str(ex.value) == ( + 'Field `foo`: Value must be less than 5') try: mixin.process_bind_param(3, None) except ValueError: @@ -131,10 +118,11 @@ class TestProcessableChoice(object): def test_no_choices(self): field = types.ProcessableChoice() + field._column_name = 'foo' with pytest.raises(ValueError) as ex: 
field.process_bind_param('foo', None) assert str(ex.value) == \ - 'Got an invalid choice `foo`. Valid choices: ()' + 'Field `foo`: Got an invalid choice `foo`. Valid choices: ()' def test_none_value(self): field = types.ProcessableChoice() @@ -145,10 +133,11 @@ def test_none_value(self): def test_value_not_in_choices(self): field = types.ProcessableChoice(choices=['foo']) + field._column_name = 'foo' with pytest.raises(ValueError) as ex: field.process_bind_param('bar', None) assert str(ex.value) == \ - 'Got an invalid choice `bar`. Valid choices: (foo)' + 'Field `foo`: Got an invalid choice `bar`. Valid choices: (foo)' def test_value_in_choices(self): field = types.ProcessableChoice(choices=['foo']) @@ -157,15 +146,6 @@ def test_value_in_choices(self): except ValueError: raise Exception('Unexpected error') - def test_processed_value_in_choices(self): - field = types.ProcessableChoice( - choices=['foo'], - processors=[lambda v: v.lower()]) - try: - field.process_bind_param('FoO', None) - except ValueError: - raise Exception('Unexpected error') - def test_choices_not_sequence(self): field = types.ProcessableChoice(choices='foo') try: @@ -290,10 +270,12 @@ def test_validate_choices_invalid(self): field = types.ProcessableChoiceArray( item_type=fields.StringField, choices=['foo', 'bar']) + field._column_name = 'mycol' with pytest.raises(ValueError) as ex: field._validate_choices(['qoo', 'foo']) assert str(ex.value) == ( - 'Got invalid choices: (qoo). Valid choices: (foo, bar)') + 'Field `mycol`: Got invalid choices: (qoo). 
' + 'Valid choices: (foo, bar)') def test_process_bind_param_postgres(self): field = types.ProcessableChoiceArray(item_type=fields.StringField) diff --git a/nefertari_sqla/types.py b/nefertari_sqla/types.py index 4d99c6c..2fa7fed 100644 --- a/nefertari_sqla/types.py +++ b/nefertari_sqla/types.py @@ -5,22 +5,10 @@ from sqlalchemy.dialects.postgresql import ARRAY, HSTORE -class ProcessableMixin(object): - """ Mixin that allows running callables on a value that - is being set to a field. - """ - def __init__(self, *args, **kwargs): - self.processors = kwargs.pop('processors', ()) - super(ProcessableMixin, self).__init__(*args, **kwargs) - - def process_bind_param(self, value, dialect): - for proc in self.processors: - value = proc(value) - return value - - -class LengthLimitedStringMixin(ProcessableMixin): +class LengthLimitedStringMixin(object): """ Mixin for custom string types which may be length limited. """ + _column_name = None + def __init__(self, *args, **kwargs): self.min_length = kwargs.pop('min_length', None) self.max_length = kwargs.pop('max_length', None) @@ -29,36 +17,38 @@ def __init__(self, *args, **kwargs): super(LengthLimitedStringMixin, self).__init__(*args, **kwargs) def process_bind_param(self, value, dialect): - value = super(LengthLimitedStringMixin, self).process_bind_param( - value, dialect) if value is not None: if (self.min_length is not None) and len(value) < self.min_length: - raise ValueError('Value length must be more than {}'.format( - self.min_length)) + raise ValueError( + 'Field `{}`: Value length must be more than {}'.format( + self._column_name, self.min_length)) if (self.max_length is not None) and len(value) > self.max_length: - raise ValueError('Value length must be less than {}'.format( - self.max_length)) + raise ValueError( + 'Field `{}`: Value length must be less than {}'.format( + self._column_name, self.max_length)) return value -class SizeLimitedNumberMixin(ProcessableMixin): +class SizeLimitedNumberMixin(object): """ 
Mixin for custom string types which may be size limited. """ + _column_name = None + def __init__(self, *args, **kwargs): self.min_value = kwargs.pop('min_value', None) self.max_value = kwargs.pop('max_value', None) super(SizeLimitedNumberMixin, self).__init__(*args, **kwargs) def process_bind_param(self, value, dialect): - value = super(SizeLimitedNumberMixin, self).process_bind_param( - value, dialect) if value is None: return value if (self.min_value is not None) and value < self.min_value: - raise ValueError('Value must be bigger than {}'.format( - self.min_value)) + raise ValueError( + 'Field `{}`: Value must be bigger than {}'.format( + self._column_name, self.min_value)) if (self.max_value is not None) and value > self.max_value: - raise ValueError('Value must be less than {}'.format( - self.max_value)) + raise ValueError( + 'Field `{}`: Value must be less than {}'.format( + self._column_name, self.max_value)) return value @@ -109,24 +99,25 @@ class LimitedNumeric(SizeLimitedNumberMixin, types.TypeDecorator): # Types that support running processors -class ProcessableDateTime(ProcessableMixin, types.TypeDecorator): +class ProcessableDateTime(types.TypeDecorator): impl = types.DateTime -class ProcessableBoolean(ProcessableMixin, types.TypeDecorator): +class ProcessableBoolean(types.TypeDecorator): impl = types.Boolean -class ProcessableDate(ProcessableMixin, types.TypeDecorator): +class ProcessableDate(types.TypeDecorator): impl = types.Date -class ProcessableChoice(ProcessableMixin, types.TypeDecorator): +class ProcessableChoice(types.TypeDecorator): """ Type that represents value from a particular set of choices. Value may be any number of choices from a provided set of valid choices. 
""" + _column_name = None impl = types.String def __init__(self, *args, **kwargs): @@ -136,40 +127,36 @@ def __init__(self, *args, **kwargs): super(ProcessableChoice, self).__init__(*args, **kwargs) def process_bind_param(self, value, dialect): - value = super(ProcessableChoice, self).process_bind_param( - value, dialect) if (value is not None) and (value not in self.choices): - err = 'Got an invalid choice `{}`. Valid choices: ({})'.format( - value, ', '.join(self.choices)) - raise ValueError(err) + err = 'Field `{}`: Got an invalid choice `{}`. Valid choices: ({})' + err_ctx = [self._column_name, value, ', '.join(self.choices)] + raise ValueError(err.format(*err_ctx)) return value -class ProcessableInterval(ProcessableMixin, types.TypeDecorator): +class ProcessableInterval(types.TypeDecorator): impl = types.Interval def process_bind_param(self, value, dialect): """ Convert seconds(int) :value: to `datetime.timedelta` instance. """ - value = super(ProcessableInterval, self).process_bind_param( - value, dialect) if isinstance(value, int): value = datetime.timedelta(seconds=value) return value -class ProcessableLargeBinary(ProcessableMixin, types.TypeDecorator): +class ProcessableLargeBinary(types.TypeDecorator): impl = types.LargeBinary -class ProcessablePickleType(ProcessableMixin, types.TypeDecorator): +class ProcessablePickleType(types.TypeDecorator): impl = types.PickleType -class ProcessableTime(ProcessableMixin, types.TypeDecorator): +class ProcessableTime(types.TypeDecorator): impl = types.Time -class ProcessableDict(ProcessableMixin, types.TypeDecorator): +class ProcessableDict(types.TypeDecorator): """ Represents a dictionary of values. 
@@ -192,8 +179,6 @@ def load_dialect_impl(self, dialect): return dialect.type_descriptor(types.UnicodeText) def process_bind_param(self, value, dialect): - value = super(ProcessableDict, self).process_bind_param( - value, dialect) if dialect.name == 'postgresql': return value if value is not None: @@ -208,7 +193,7 @@ def process_result_value(self, value, dialect): return value -class ProcessableChoiceArray(ProcessableMixin, types.TypeDecorator): +class ProcessableChoiceArray(types.TypeDecorator): """ Represents a list of values. If 'postgresql' is used, postgress.ARRAY type is used for db column @@ -217,6 +202,7 @@ class ProcessableChoiceArray(ProcessableMixin, types.TypeDecorator): Supports providing :choices: argument which limits the set of values that may be stored in this field. """ + _column_name = None impl = ARRAY def __init__(self, *args, **kwargs): @@ -250,14 +236,13 @@ def _validate_choices(self, value): invalid_choices = set(value) - set(self.choices) if invalid_choices: - raise ValueError( - 'Got invalid choices: ({}). Valid choices: ({})'.format( - ', '.join(invalid_choices), ', '.join(self.choices))) + err = 'Field `{}`: Got invalid choices: ({}). Valid choices: ({})' + err_ctx = [self._column_name, ', '.join(invalid_choices), + ', '.join(self.choices)] + raise ValueError(err.format(*err_ctx)) return value def process_bind_param(self, value, dialect): - value = super(ProcessableChoiceArray, self).process_bind_param( - value, dialect) value = self._validate_choices(value) if dialect.name == 'postgresql': return value diff --git a/setup.py b/setup.py index 0f6535b..71d4f5e 100644 --- a/setup.py +++ b/setup.py @@ -9,13 +9,13 @@ 'sqlalchemy_utils', 'elasticsearch', 'pyramid_tm', - 'nefertari==0.3.0' + 'nefertari==0.3.1' ] setup( name='nefertari_sqla', - version="0.2.1", + version="0.2.2", description='sqla engine for nefertari', classifiers=[ "Programming Language :: Python",