diff --git a/newsroom/agenda/views.py b/newsroom/agenda/views.py index 629225b30..d44d4790a 100644 --- a/newsroom/agenda/views.py +++ b/newsroom/agenda/views.py @@ -14,7 +14,7 @@ from newsroom.products.products import get_products_by_company from newsroom.template_filters import is_admin_or_internal, is_admin from newsroom.topics import get_company_folders, get_user_folders, get_user_topics -from newsroom.navigations.navigations import get_navigations_by_company +from newsroom.navigations.navigations import get_navigations from newsroom.auth import get_company, get_user, get_user_id, get_user_required from newsroom.decorator import login_required, section from newsroom.utils import ( @@ -132,13 +132,7 @@ def get_view_data() -> Dict: for f in app.download_formatters.values() if "agenda" in f["types"] ], - "navigations": get_navigations_by_company( - company, - product_type="agenda", - events_only=company.get("events_only", False), - ) - if company - else [], + "navigations": get_navigations(user, company, "agenda"), "saved_items": get_resource_service("agenda").get_saved_items_count(), "events_only": company.get("events_only", False) if company else False, "restrict_coverage_info": company.get("restrict_coverage_info", False) if company else False, diff --git a/newsroom/navigations/navigations.py b/newsroom/navigations/navigations.py index 0760653f3..33ea90303 100644 --- a/newsroom/navigations/navigations.py +++ b/newsroom/navigations/navigations.py @@ -1,9 +1,10 @@ +from typing import List, Optional import newsroom from newsroom.products.products import get_products_by_company, get_products_by_user import superdesk from newsroom.utils import is_admin -from newsroom.types import Company, UserData +from newsroom.types import Company, Navigation, UserData class NavigationsResource(newsroom.Resource): @@ -38,14 +39,19 @@ def on_delete(self, doc): superdesk.get_resource_service("products").patch(product["_id"], product) -def get_navigations_by_company(company: Company, product_type="wire", events_only=False): +def get_navigations(user: Optional[UserData], company: Optional[Company], product_type="wire") -> List[Navigation]: """ - Returns list of navigations for given company id - Navigations will contain the list of product ids + Returns list of navigations for given user and company """ - products = get_products_by_company(company, None, product_type, True) + if user and is_admin(user): + return list(superdesk.get_resource_service("navigations").get(req=None, lookup={"product_type": product_type})) + + products = [] + if company: + products += get_products_by_company(company, None, product_type, True) + if user: + products += get_products_by_user(user, product_type, None) - # Get the navigation ids used across products navigation_ids = [] for p in products: if p.get("navigations"): @@ -53,6 +59,14 @@ def get_navigations_by_company(company: Company, product_type="wire", events_onl return get_navigations_by_ids(navigation_ids) +def get_navigations_by_company(company: Company, product_type="wire"): + """ + Returns list of navigations for given company id + Navigations will contain the list of product ids + """ + return get_navigations(None, company, product_type) + + def get_navigations_by_ids(navigation_ids): """ Returns the list of navigations for navigation_ids @@ -65,22 +79,3 @@ def get_navigations_by_ids(navigation_ids): req=None, lookup={"_id": {"$in": navigation_ids}, "is_enabled": True} ) ) - - -def get_navigations_by_user(user: UserData, product_type="wire", events_only=False): - """ - Returns list of navigations for given user id - Navigations will contain the list of product ids - """ - - if is_admin(user): - return list(superdesk.get_resource_service("navigations").get(req=None, lookup={"product_type": product_type})) - - products = get_products_by_user(user, product_type, None) - - # Get the navigation ids used across products - navigation_ids = [] - for p in products: - if p.get("navigations"): - navigation_ids.extend(p["navigations"]) - return get_navigations_by_ids(navigation_ids) diff --git a/newsroom/types.py b/newsroom/types.py index b1b7759f5..981cb938a 100644 --- a/newsroom/types.py +++ b/newsroom/types.py @@ -266,3 +266,7 @@ class DashboardCard(TypedDict): Article = Dict[str, Any] + + +class Navigation(Entity): + name: str diff --git a/newsroom/wire/views.py b/newsroom/wire/views.py index c9c53673b..a21cde469 100644 --- a/newsroom/wire/views.py +++ b/newsroom/wire/views.py @@ -17,7 +17,7 @@ from newsroom.auth.utils import check_user_has_products, is_valid_session from newsroom.cards import get_card_size, get_card_type -from newsroom.navigations.navigations import get_navigations_by_user, get_navigations_by_company +from newsroom.navigations.navigations import get_navigations from newsroom.products.products import get_products_by_company from newsroom.wire import blueprint from newsroom.wire.utils import update_action_list @@ -109,8 +109,7 @@ def get_view_data() -> Dict: for f in app.download_formatters.values() if "wire" in f["types"] ], - "navigations": (get_navigations_by_user(user, "wire") if user else []) - + (get_navigations_by_company(company, "wire") if company else []), + "navigations": get_navigations(user, company, "wire"), "products": products, "saved_items": get_bookmarks_count(user["_id"], "wire"), "context": "wire", diff --git a/tests/core/test_navigations.py b/tests/core/test_navigations.py index 057c6cfe3..ab710fc34 100644 --- a/tests/core/test_navigations.py +++ b/tests/core/test_navigations.py @@ -4,13 +4,16 @@ from newsroom.navigations.views import add_remove_products_for_navigation from newsroom.products.products import get_products_by_navigation +from newsroom.products.views import get_product_ref from newsroom.tests.users import test_login_succeeds_for_admin # noqa from newsroom.tests.fixtures import COMPANY_1_ID -from newsroom.navigations.navigations import get_navigations_by_company, get_navigations_by_user +from newsroom.navigations.navigations import get_navigations +from newsroom.types import Product from tests.core.utils import add_company_products NAV_ID = ObjectId("59b4c5c61d41c8d736852fbf") +AGENDA_NAV_ID = ObjectId() @fixture(autouse=True) @@ -26,6 +29,7 @@ def navigations(app): "is_enabled": True, }, { + "_id": AGENDA_NAV_ID, "name": "Calendar", "product_type": "agenda", "is_enabled": True, @@ -232,9 +236,9 @@ def test_get_agenda_navigations_by_company_returns_ordered(client, app): test_login_succeeds_for_admin(client) company = app.data.find_one("companies", req=None, _id=COMPANY_1_ID) - navigations = get_navigations_by_company(company, "agenda") + navigations = get_navigations(None, company, "agenda") assert navigations[0].get("name") == "Uber" - navigations = get_navigations_by_company(company, "wire") + navigations = get_navigations(None, company, "wire") assert navigations[0].get("name") == "Sport" @@ -277,11 +281,47 @@ def test_get_products_by_navigation_caching(app): assert 0 == len(get_products_by_navigation([nav_id], "wire")) -def test_get_navigations_by_user_for_admin(admin): - navigations = get_navigations_by_user(admin, "wire") +def test_get_navigations_for_admin(admin): + navigations = get_navigations(admin, None, "wire") assert 1 == len(navigations) assert "Sport" == navigations[0]["name"] - navigations = get_navigations_by_user(admin, "agenda") + navigations = get_navigations(admin, None, "agenda") + assert 1 == len(navigations) + assert "Calendar" == navigations[0]["name"] + + +def test_get_navigations_for_user(public_user, public_company, app): + navigations = get_navigations(public_user, public_company, "wire") + assert 0 == len(navigations) + + navigations = get_navigations(public_user, public_company, "agenda") + assert 0 == len(navigations) + + products = [ + Product( + _id=ObjectId(), + name="Wire", + navigations=[NAV_ID], + is_enabled=True, + product_type="wire", + ), + Product( + _id=ObjectId(), + name="Agenda", + navigations=[AGENDA_NAV_ID], + is_enabled=True, + product_type="agenda", + ), + ] + + app.data.insert("products", products) + public_user["products"] = [get_product_ref(products[0]), get_product_ref(products[1])] + + navigations = get_navigations(public_user, public_company, "wire") + assert 1 == len(navigations) + assert "Sport" == navigations[0]["name"] + + navigations = get_navigations(public_user, public_company, "agenda") assert 1 == len(navigations) assert "Calendar" == navigations[0]["name"]