diff --git a/setup.py b/setup.py index 12245610c8..c4649f9977 100644 --- a/setup.py +++ b/setup.py @@ -40,13 +40,12 @@ "requests>=2.7.0,<3.0", "boto3>=1.26,<2.0", "websockets>=10.3,<10.4", - "mongolock>=1.3.4,<1.4", "PyYAML>=6.0.1", "lxml>=5.2.2,<5.3", "lxml_html_clean>=0.1.1,<0.2", "python-twitter>=3.5,<3.6", "chardet<6.0", - "pymongo>=3.8,<3.12", + "pymongo>=4.7.3,<4.8", "croniter<2.1", "python-dateutil<2.10", "unidecode>=0.04.21,<=1.3.8", @@ -67,6 +66,8 @@ "itsdangerous>=1.1,<2.0", "pymemcache>=4.0,<4.1", "xmlsec>=1.3.13,<1.3.15", + # Async libraries + "motor>=3.4.0,<4.0", ] package_data = { diff --git a/superdesk/emails/__init__.py b/superdesk/emails/__init__.py index 7e0b1e68b3..1c18a9c850 100644 --- a/superdesk/emails/__init__.py +++ b/superdesk/emails/__init__.py @@ -177,7 +177,7 @@ def send_activity_emails(activity, recipients): subject = render_template("notification_subject.txt", notification=notification) send_email.delay(subject=subject, sender=admins[0], recipients=recipients, text_body=text_body, html_body=html_body) - email_timestamps.update({"_id": message_id}, {"_id": message_id, "_created": now}, upsert=True) + email_timestamps.update_one({"_id": message_id}, {"$set": {"_id": message_id, "_created": now}}, upsert=True) def send_article_killed_email(article, recipients, transmitted_at): diff --git a/superdesk/eve_backend.py b/superdesk/eve_backend.py index 83f96c21d8..955213a3f8 100644 --- a/superdesk/eve_backend.py +++ b/superdesk/eve_backend.py @@ -126,9 +126,14 @@ def get(self, endpoint_name, req, lookup, **kwargs): backend = self._lookup_backend(endpoint_name, fallback=True) is_mongo = self._backend(endpoint_name) == backend - cursor, count = backend.find(endpoint_name, req, lookup, perform_count=req.if_modified_since) + cursor, _ = backend.find(endpoint_name, req, lookup, perform_count=False) - if req.if_modified_since and count: + try: + has_items = cursor[0] is not None + except IndexError: + has_items = False + + if req.if_modified_since 
and has_items: # fetch all items, not just updated req.if_modified_since = None cursor, count = backend.find(endpoint_name, req, lookup, perform_count=False) @@ -137,7 +142,7 @@ def get(self, endpoint_name, req, lookup, **kwargs): if is_mongo and source_config.get("collation"): cursor.collation(Collation(locale=app.config.get("MONGO_LOCALE", "en_US"))) - self._cursor_hook(cursor=cursor, req=req) + self._cursor_hook(cursor=cursor, endpoint_name=endpoint_name, req=req, lookup=lookup) return cursor def get_from_mongo(self, endpoint_name, req, lookup, perform_count=False): @@ -151,8 +156,8 @@ def get_from_mongo(self, endpoint_name, req, lookup, perform_count=False): """ req.if_modified_since = None backend = self._backend(endpoint_name) - cursor, _ = backend.find(endpoint_name, req, lookup, perform_count=perform_count) - self._cursor_hook(cursor=cursor, req=req) + cursor, _ = backend.find(endpoint_name, req, lookup, perform_count=False) + self._cursor_hook(cursor=cursor, endpoint_name=endpoint_name, req=req, lookup=lookup) return cursor def find_and_modify(self, endpoint_name, **kwargs): @@ -166,7 +171,7 @@ def find_and_modify(self, endpoint_name, **kwargs): if kwargs.get("query"): kwargs["query"] = backend._mongotize(kwargs["query"], endpoint_name) - result = backend.driver.db[endpoint_name].find_and_modify(**kwargs) + result = backend.driver.db[endpoint_name].find_one_and_update(**kwargs) cache.clean([endpoint_name]) return result @@ -464,9 +469,32 @@ def _set_parent(self, endpoint_name, doc, lookup): if parent: lookup["parent"] = parent - def _cursor_hook(self, cursor, req): + def construct_count_function(self, resource, req, lookup): + backend = self._backend(resource) + + client_sort = backend._convert_sort_request_to_dict(req) + spec = backend._convert_where_request_to_dict(resource, req) + + if lookup: + spec = backend.combine_queries(spec, lookup) + + spec = backend._mongotize(spec, resource) + client_projection = backend._client_projection(req) + + datasource, 
spec, projection, sort = backend._datasource_ex(resource, spec, client_projection, client_sort) + target = backend.pymongo(resource).db[datasource] + + def count_function(): + return target.count_documents(spec) + + return count_function + + def _cursor_hook(self, cursor, endpoint_name, req, lookup): """Apply additional methods for cursor""" + if not hasattr(cursor, "count"): + setattr(cursor, "count", self.construct_count_function(endpoint_name, req, lookup)) + if not req or not req.args: return diff --git a/superdesk/lock.py b/superdesk/lock.py index 1ffd4d2bfa..b169450b71 100644 --- a/superdesk/lock.py +++ b/superdesk/lock.py @@ -4,7 +4,9 @@ import logging from datetime import datetime -from mongolock import MongoLock, MongoLockException + +# from mongolock import MongoLock, MongoLockException +from superdesk.mongolock import MongoLock, MongoLockException from werkzeug.local import LocalProxy from flask import current_app as app from superdesk.logging import logger diff --git a/superdesk/mongolock.py b/superdesk/mongolock.py new file mode 100644 index 0000000000..931076b582 --- /dev/null +++ b/superdesk/mongolock.py @@ -0,0 +1,129 @@ +# Copied from https://github.com/lorehov/mongolock, to support newer PyMongo lib + +import time +import contextlib +from datetime import datetime, timedelta + +from pymongo import MongoClient +from pymongo.errors import DuplicateKeyError +from pymongo.collection import Collection + + +class MongoLockException(Exception): + pass + + +class MongoLockLocked(Exception): + pass + + +class MongoLock(object): + def __init__(self, host="localhost", db="mongolock", collection="lock", client=None, acquire_retry_step=0.1): + """Create a new instance of MongoLock. 
+
+        :Parameters:
+            - `host` (optional) - use it to manually specify mongodb connection string
+            - `db` (optional) - db name
+            - `collection` (optional) - collection name or :class:`pymongo.Collection` instance
+            - `client` - instance of :class:`MongoClient` or :class:`MongoReplicaSetClient`,
+            - `acquire_retry_step` (optional)- time in seconds between retries while trying to acquire the lock,
+               if specified - `host` parameter will be skipped
+        """
+        self.acquire_retry_step = acquire_retry_step
+        if isinstance(collection, Collection):
+            self.collection = collection
+        else:
+            if client is not None:
+                self.client = client
+            else:
+                self.client = MongoClient(host)
+            self.collection = self.client[db][collection]
+
+    @contextlib.contextmanager
+    def __call__(self, key, owner, timeout=None, expire=None):
+        """See `lock` method."""
+        if not self.lock(key, owner, timeout, expire):
+            status = self.get_lock_info(key)
+            raise MongoLockLocked(
+                "Timeout, lock owned by {owner} since {ts}, expire time is {expire}".format(
+                    owner=status["owner"], ts=status["created"], expire=status["expire"]
+                )
+            )
+        try:
+            yield
+        finally:
+            self.release(key, owner)
+
+    def lock(self, key, owner, timeout=None, expire=None):
+        """Lock given `key` to `owner`.
+
+        :Parameters:
+            - `key` - lock name
+            - `owner` - name of application/component/whatever which asks for lock
+            - `timeout` (optional) - how long to wait if `key` is locked
+            - `expire` (optional) - when given, lock will be released after that number of seconds.
+
+        Returns ``False`` if the lock can't be acquired before `timeout`.
+ """ + expire = datetime.utcnow() + timedelta(seconds=expire) if expire else None + try: + self.collection.insert_one( + {"_id": key, "locked": True, "owner": owner, "created": datetime.utcnow(), "expire": expire} + ) + return True + except DuplicateKeyError: + start_time = datetime.utcnow() + while True: + if self._try_get_lock(key, owner, expire): + return True + + if not timeout or datetime.utcnow() >= start_time + timedelta(seconds=timeout): + return False + time.sleep(self.acquire_retry_step) + + def release(self, key, owner): + """Release lock with given name. + `key` - lock name + `owner` - name of application/component/whatever which held a lock + Raises `MongoLockException` if no such a lock. + """ + status = self.collection.find_and_modify( + {"_id": key, "owner": owner}, {"locked": False, "owner": None, "created": None, "expire": None} + ) + + def get_lock_info(self, key): + """Get lock status.""" + return self.collection.find_one({"_id": key}) + + def is_locked(self, key): + lock_info = self.get_lock_info(key) + return not ( + not lock_info + or not lock_info["locked"] + or (lock_info["expire"] is not None and lock_info["expire"] < datetime.utcnow()) + ) + + def touch(self, key, owner, expire=None): + """Renew lock to avoid expiration.""" + lock = self.collection.find_one({"_id": key, "owner": owner}) + if not lock: + raise MongoLockException("Can't find lock for {key}: {owner}".format(key=key, owner=owner)) + if not lock["expire"]: + return + if not expire: + raise MongoLockException("Can't touch lock without expire for {0}: {1}".format(key, owner)) + expire = datetime.utcnow() + timedelta(seconds=expire) + self.collection.update_one({"_id": key, "owner": owner}, {"$set": {"expire": expire}}) + + def _try_get_lock(self, key, owner, expire): + dtnow = datetime.utcnow() + result = self.collection.update_one( + { + "$or": [ + {"_id": key, "locked": False}, + {"_id": key, "expire": {"$lt": dtnow}}, + ] + }, + {"$set": {"locked": True, "owner": owner, 
"created": dtnow, "expire": expire}}, + ) + return result and result.acknowledged and result.modified_count == 1 diff --git a/superdesk/services.py b/superdesk/services.py index 90e4e8008a..3d952e0a4b 100644 --- a/superdesk/services.py +++ b/superdesk/services.py @@ -134,8 +134,7 @@ def get_all(self): return self.get_from_mongo(None, {}).sort("_id") def find_and_modify(self, query, update, **kwargs): - res = self.backend.find_and_modify(self.datasource, query=query, update=update, **kwargs) - return res + return self.backend.find_and_modify(self.datasource, filter=query, update=update, **kwargs) def get_all_batch(self, size=500, max_iterations=10000, lookup=None): """Gets all items using multiple queries. diff --git a/superdesk/storage/desk_media_storage.py b/superdesk/storage/desk_media_storage.py index 9e878865d8..e4b5fe63f8 100644 --- a/superdesk/storage/desk_media_storage.py +++ b/superdesk/storage/desk_media_storage.py @@ -10,6 +10,7 @@ from typing import Optional from flask import current_app as app +from flask_babel import _ import logging import json import mimetypes @@ -17,9 +18,11 @@ import bson.errors import gridfs import os.path +import hashlib from eve.io.mongo.media import GridFSMediaStorage +from superdesk.errors import SuperdeskApiError from . 
import SuperdeskMediaStorage @@ -108,10 +111,29 @@ def put(self, content, filename=None, content_type=None, metadata=None, resource if filename: filename = "{}/{}".format(folder, filename) + if hasattr(content, "read"): + data = content.read() + if hasattr(data, "encode"): + data = data.encode("utf-8") + hash_data = hashlib.md5(data) + if hasattr(content, "seek"): + content.seek(0) + elif isinstance(content, bytes): + hash_data = hashlib.md5(content) + elif isinstance(content, str): + hash_data = hashlib.md5(content.encode("utf-8")) + else: + raise SuperdeskApiError.badRequestError(_("Unsupported content type")) + try: logger.info("Adding file {} to the GridFS".format(filename)) return self.fs(resource).put( - content, content_type=content_type, filename=filename, metadata=metadata, **kwargs + content, + content_type=content_type, + filename=filename, + metadata=metadata, + md5=hash_data.hexdigest(), + **kwargs, ) except gridfs.errors.FileExists: logger.info("File exists filename=%s id=%s" % (filename, kwargs["_id"])) diff --git a/superdesk/tests/__init__.py b/superdesk/tests/__init__.py index 8a99886e84..6f0adfb4ca 100644 --- a/superdesk/tests/__init__.py +++ b/superdesk/tests/__init__.py @@ -167,7 +167,6 @@ def drop_mongo(app): dbname = app.config[name] dbconn = app.data.mongo.pymongo(prefix=prefix).cx dbconn.drop_database(dbname) - dbconn.close() def setup_config(config): @@ -356,6 +355,14 @@ def inner(*a, **kw): def setup(context=None, config=None, app_factory=get_app, reset=False): if not hasattr(setup, "app") or setup.reset or config: + if hasattr(setup, "app"): + # Close all PyMongo Connections (new ones will be created with ``app_factory`` call) + for key, val in setup.app.extensions["pymongo"].items(): + val[0].close() + + if hasattr(setup.app, "async_app"): + setup.app.async_app.stop() + cfg = setup_config(config) setup.app = app_factory(cfg) setup.reset = reset diff --git a/tests/commands/data_updates_test.py b/tests/commands/data_updates_test.py index 
44aebda49f..847a5809b9 100644 --- a/tests/commands/data_updates_test.py +++ b/tests/commands/data_updates_test.py @@ -56,7 +56,7 @@ def number_of_data_updates_applied(self): def test_dry_data_update(self): superdesk.commands.data_updates.DEFAULT_DATA_UPDATE_FW_IMPLEMENTATION = """ - count = mongodb_collection.find({}).count() + count = mongodb_collection.count_documents({}) assert count == 0, count """ self.assertEqual(self.number_of_data_updates_applied(), 0) @@ -79,18 +79,18 @@ def test_data_update(self): # create migrations for index in range(40): superdesk.commands.data_updates.DEFAULT_DATA_UPDATE_FW_IMPLEMENTATION = """ - assert mongodb_collection - count = mongodb_collection.find({}).count() + assert mongodb_collection is not None + count = mongodb_collection.count_documents({}) assert count == %d, count - assert mongodb_database + assert mongodb_database is not None """ % ( index ) superdesk.commands.data_updates.DEFAULT_DATA_UPDATE_BW_IMPLEMENTATION = """ - assert mongodb_collection - count = mongodb_collection.find({}).count() + assert mongodb_collection is not None + count = mongodb_collection.count_documents({}) assert count == %d, count - assert mongodb_database + assert mongodb_database is not None """ % ( index + 1 ) diff --git a/tests/content_api_test.py b/tests/content_api_test.py index 5290067fa2..58073bb5dc 100644 --- a/tests/content_api_test.py +++ b/tests/content_api_test.py @@ -42,13 +42,13 @@ def _auth_headers(self, sub=None): def test_publish_to_content_api(self): item = {"guid": "foo", "type": "text", "task": {"desk": "foo"}, "rewrite_of": "bar"} self.content_api.publish(item) - self.assertEqual(1, self.db.items.count()) + self.assertEqual(1, self.db.items.count_documents({})) self.assertNotIn("task", self.db.items.find_one()) self.assertEqual("foo", self.db.items.find_one()["_id"]) item["_current_version"] = "2" self.content_api.publish(item) - self.assertEqual(1, self.db.items.count()) + self.assertEqual(1, 
self.db.items.count_documents({})) item["_current_version"] = "3" item["headline"] = "foo" @@ -56,7 +56,7 @@ def test_publish_to_content_api(self): self.assertEqual("foo", self.db.items.find_one()["headline"]) self.assertEqual("bar", self.db.items.find_one()["evolvedfrom"]) - self.assertEqual(3, self.db.items_versions.count()) + self.assertEqual(3, self.db.items_versions.count_documents({})) def test_create_keeps_planning_metadata(self): item = { @@ -76,8 +76,8 @@ def test_publish_with_subscriber_ids(self): subscribers = [{"_id": ObjectId()}, {"_id": ObjectId()}] self.content_api.publish(item, subscribers) - self.assertEqual(1, self.db.items.find({"subscribers": str(subscribers[0]["_id"])}).count()) - self.assertEqual(0, self.db.items.find({"subscribers": "foo"}).count()) + self.assertEqual(1, self.db.items.count_documents({"subscribers": str(subscribers[0]["_id"])})) + self.assertEqual(0, self.db.items.count_documents({"subscribers": "foo"})) def test_content_filtering_by_subscriber(self): subscriber = {"_id": "sub1"} @@ -451,9 +451,9 @@ def test_api_block(self): self.app.data.insert("content_filters", [content_filter]) self.content_api.publish({"_id": "foo", "source": "fred", "type": "text", "guid": "foo"}) - self.assertEqual(0, self.db.items.count()) + self.assertEqual(0, self.db.items.count_documents({})) self.content_api.publish({"_id": "bar", "source": "jane", "type": "text", "guid": "bar"}) - self.assertEqual(1, self.db.items.count()) + self.assertEqual(1, self.db.items.count_documents({})) def test_item_versions_api(self): subscriber = {"_id": "sub1"} @@ -521,9 +521,9 @@ def test_publish_kill_to_content_api(self): item["_current_version"] = 2 self.content_api.publish(item, [subscriber]) - self.assertEqual(1, self.db.items.count()) + self.assertEqual(1, self.db.items.count_documents({})) self.assertEqual("canceled", self.db.items.find_one()["pubstatus"]) - self.assertEqual(2, self.db.items_versions.count()) + self.assertEqual(2, 
self.db.items_versions.count_documents({})) for i in self.db.items_versions.find(): self.assertEqual(i.get("pubstatus"), "canceled") @@ -538,14 +538,14 @@ def test_publish_kill_to_content_api(self): def test_publish_item_with_ancestors(self): item = {"guid": "foo", "type": "text", "task": {"desk": "foo"}, "bookmarks": [ObjectId()]} self.content_api.publish(item) - self.assertEqual(1, self.db.items.count()) + self.assertEqual(1, self.db.items.count_documents({})) self.assertNotIn("ancestors", self.db.items.find_one({"_id": "foo"})) item["guid"] = "bar" item["rewrite_of"] = "foo" self.content_api.publish(item) - self.assertEqual(2, self.db.items.count()) + self.assertEqual(2, self.db.items.count_documents({})) bar = self.db.items.find_one({"_id": "bar"}) self.assertEqual(["foo"], bar.get("ancestors", [])) self.assertEqual("foo", bar.get("evolvedfrom")) @@ -556,7 +556,7 @@ def test_publish_item_with_ancestors(self): item["rewrite_of"] = "bar" self.content_api.publish(item) - self.assertEqual(3, self.db.items.count()) + self.assertEqual(3, self.db.items.count_documents({})) fun = self.db.items.find_one({"_id": "fun"}) self.assertEqual(["foo", "bar"], fun.get("ancestors", [])) self.assertEqual("bar", fun.get("evolvedfrom")) @@ -707,7 +707,7 @@ def test_associated_item_filter_by_subscriber(self): subscriber1 = {"_id": "sub1"} subscriber2 = {"_id": "sub2"} self.content_api.publish(item, [subscriber1, subscriber2]) - self.assertEqual(1, self.db.items.count()) + self.assertEqual(1, self.db.items.count_documents({})) with self.capi.test_client() as c: response = c.get("items/foo", headers=self._auth_headers(subscriber1)) data = json.loads(response.data) diff --git a/tests/datalayer_tests.py b/tests/datalayer_tests.py index c7d21cf221..fc2ddf51a1 100644 --- a/tests/datalayer_tests.py +++ b/tests/datalayer_tests.py @@ -40,7 +40,9 @@ def test_find_with_mongo_query(self): ) self.assertEqual(1, service.find({"resource": {"$in": ["foo"]}}).count()) - self.assertEqual(1, 
service.find({}, max_results=1).count(True)) + # We no longer support ``with_limit_and_skip`` attribute with count + # it was only supported in MongoCursor anyway + self.assertEqual(2, service.find({}, max_results=1).count()) def test_set_custom_etag_on_create(self): service = superdesk.get_resource_service("activity") diff --git a/tests/io/update_ingest_tests.py b/tests/io/update_ingest_tests.py index d5474a1235..8acc76f8be 100644 --- a/tests/io/update_ingest_tests.py +++ b/tests/io/update_ingest_tests.py @@ -257,7 +257,7 @@ def test_expiring_content_with_files(self): # four files in grid fs current_files = self.app.media.storage().fs("upload").find() - self.assertEqual(4, current_files.count()) + self.assertEqual(4, len(list(current_files))) with patch("superdesk.io.commands.remove_expired_content.utcnow", return_value=now + timedelta(hours=20)): remove = RemoveExpiredContent() @@ -265,7 +265,7 @@ def test_expiring_content_with_files(self): # all gone current_files = self.app.media.storage().fs("upload").find() - self.assertEqual(0, current_files.count()) + self.assertEqual(0, len(list(current_files))) def test_apply_rule_set(self): item = {"body_html": "@@body@@"} @@ -328,7 +328,7 @@ def test_files_dont_duplicate_ingest(self): # 12 files in grid fs current_files = self.app.media.storage().fs("upload").find() - self.assertEqual(12, current_files.count()) + self.assertEqual(12, len(list(current_files))) def test_anpa_category_to_subject_derived_ingest(self): vocab = [ diff --git a/tests/storage/gridfs_media_storage_test.py b/tests/storage/gridfs_media_storage_test.py index c0d437cf10..a211102a83 100644 --- a/tests/storage/gridfs_media_storage_test.py +++ b/tests/storage/gridfs_media_storage_test.py @@ -3,7 +3,7 @@ import eve import bson import unittest -from unittest.mock import Mock +from unittest.mock import Mock, ANY from superdesk.upload import bp, upload_url from superdesk.datalayer import SuperdeskDataLayer from superdesk.storage import 
SuperdeskGridFSMediaStorage @@ -51,6 +51,7 @@ def test_put_media_with_id(self): "filename": filename, "metadata": None, "_id": _id, + "md5": ANY, } gridfs.put.assert_called_once_with(data, **kwargs) @@ -65,7 +66,12 @@ def test_put_into_folder(self): with self.app.app_context(): self.media.put(data, filename=filename, content_type="text/plain", folder=folder) - kwargs = {"content_type": "text/plain", "filename": "{}/{}".format(folder, filename), "metadata": None} + kwargs = { + "content_type": "text/plain", + "filename": "{}/{}".format(folder, filename), + "metadata": None, + "md5": ANY, + } gridfs.put.assert_called_once_with(data, **kwargs) @@ -117,6 +123,7 @@ def test_mimetype_detect(self): "filename": filename, "metadata": None, "_id": _id, + "md5": ANY, } gridfs.put.assert_called_once_with(content, **kwargs) @@ -133,6 +140,7 @@ def test_mimetype_detect(self): "filename": filename, "metadata": None, "_id": _id, + "md5": ANY, } gridfs.put.assert_called_once_with(content, **kwargs) @@ -148,6 +156,7 @@ def test_mimetype_detect(self): "filename": filename, "metadata": None, "_id": _id, + "md5": ANY, } gridfs.put.assert_called_once_with(content, **kwargs) @@ -165,6 +174,7 @@ def test_mimetype_detect(self): "filename": filename, "metadata": None, "_id": _id, + "md5": ANY, } gridfs.put.assert_called_once_with(content, **kwargs) @@ -180,6 +190,7 @@ def test_mimetype_detect(self): "filename": filename, "metadata": None, "_id": _id, + "md5": ANY, } gridfs.put.assert_called_once_with(content, **kwargs) @@ -195,6 +206,7 @@ def test_mimetype_detect(self): "filename": filename, "metadata": None, "_id": _id, + "md5": ANY, } gridfs.put.assert_called_once_with(content, **kwargs) @@ -210,5 +222,6 @@ def test_mimetype_detect(self): "filename": filename, "metadata": None, "_id": _id, + "md5": ANY, } gridfs.put.assert_called_once_with(content, **kwargs)