diff --git a/core/src/zeit/connector/postgresql.py b/core/src/zeit/connector/postgresql.py index 5aa2cd127e..6188ba9cb1 100644 --- a/core/src/zeit/connector/postgresql.py +++ b/core/src/zeit/connector/postgresql.py @@ -6,6 +6,7 @@ import collections import os import os.path +import secrets import time from gocept.cache.property import TransactionBoundCache @@ -20,6 +21,7 @@ UnicodeText, UniqueConstraint, Uuid, + schema, select, ) from sqlalchemy.dialects.postgresql import JSONB @@ -329,14 +331,43 @@ def move(self, old_id, new_id): self.property_cache.pop(old_id, None) self.body_cache.pop(old_id, None) - def lock(self, uniqueid, principal, until): - pass - - def unlock(self, uniqueid, locktoken=None): - pass - - def locked(self, uniqueid): - pass + def lock(self, id, principal, until): + # XXX should check for existing lock + path = self.session.get(Paths, self._pathkey(id)) + token = secrets.token_hex() + self.session.add( + Lock( + parent_path=path.parent_path, + name=path.name, + id=path.id, + principal=principal, + until=until, + token=token, + ) + ) + return token + + def unlock(self, id, locktoken=None): + path = self.session.get(Paths, self._pathkey(id)) + if path is None: + return + lock = self.session.get(Lock, (path.parent_path, path.name, path.id)) + # XXX is this the stored locktoken in connector or what does the test mean? + if lock and locktoken is None: + locktoken = lock.token + if lock is None or lock.token != locktoken: + return + self.session.delete(lock) + return locktoken + + def locked(self, id): + path = self.session.get(Paths, self._pathkey(id)) + if path is None: + return False + lock = self.session.get(Lock, (path.parent_path, path.name, path.id)) + if lock is None: + return (None, None, False) + return (lock.principal, lock.until, lock is not None) def search(self, attrlist, expr): if ( @@ -394,6 +425,26 @@ class Paths(DBObject): ) +class Lock(DBObject): + __tablename__ = 'locks' + + parent_path = Column(Unicode, primary_key=True) + name = Column(Unicode, primary_key=True) + id = Column(Uuid(as_uuid=False), primary_key=True) + principal = Column(Unicode, nullable=False) + until = Column(TIMESTAMP(timezone=True), nullable=False) + token = Column(Unicode, nullable=False) + + __table_args__ = ( + schema.ForeignKeyConstraint( + (parent_path, name, id), + (Paths.parent_path, Paths.name, Paths.id), + onupdate='CASCADE', + ondelete='CASCADE', + ), + ) + + class Properties(DBObject): __tablename__ = 'properties' diff --git a/core/src/zeit/connector/tests/test_contract.py b/core/src/zeit/connector/tests/test_contract.py index 88ce919127..f131a732ae 100644 --- a/core/src/zeit/connector/tests/test_contract.py +++ b/core/src/zeit/connector/tests/test_contract.py @@ -359,7 +359,7 @@ class ContractMock( class ContractSQL( ContractReadWrite, ContractCopyMove, - # ContractLock, + ContractLock, ContractSearch, zeit.connector.testing.SQLTest, ): @@ -367,5 +367,5 @@ class ContractSQL( copy_inherited_functions(ContractReadWrite, locals()) copy_inherited_functions(ContractCopyMove, locals()) - # copy_inherited_functions(ContractLock, locals()) + copy_inherited_functions(ContractLock, locals()) copy_inherited_functions(ContractSearch, locals())