From 0075f1788ddd968c00f439c083c9926cfbd80dfe Mon Sep 17 00:00:00 2001 From: Jac Fitzgerald Date: Sat, 21 Dec 2024 00:16:17 -0800 Subject: [PATCH 1/2] fix for error when item is not in first page I pushed the parent-project filter into the request so we only get the results that match. We should rename this method [and others] to make it clear they will return one item or throw an error. --- tabcmd/commands/server.py | 22 ++++--- tests/commands/test_server.py | 110 +--------------------------------- 2 files changed, 11 insertions(+), 121 deletions(-) diff --git a/tabcmd/commands/server.py b/tabcmd/commands/server.py index 46e77d01..018e9737 100644 --- a/tabcmd/commands/server.py +++ b/tabcmd/commands/server.py @@ -61,18 +61,22 @@ def get_items_by_name(logger, item_endpoint, item_name: str, container: Optional req_option.filter.add( TSC.Filter(TSC.RequestOptions.Field.Name, TSC.RequestOptions.Operator.Equals, item_name) ) + if container: + req_option.filter.add( + TSC.Filter( + TSC.RequestOptions.Field.ParentProjectId, TSC.RequestOptions.Operator.Equals, container.id + ) + ) + all_items, pagination_item = item_endpoint.get(req_option) - if all_items is None or all_items == []: + if pagination_item.total_available == 0: raise TSC.ServerResponseError( code="404", summary=_("errors.xmlapi.not_found"), detail=_("errors.xmlapi.not_found") + ": " + item_log_name, ) - if total_available_items is None: - total_available_items = pagination_item.total_available - total_retrieved_items += len(all_items) logger.debug( @@ -85,14 +89,8 @@ def get_items_by_name(logger, item_endpoint, item_name: str, container: Optional ) ) - if container: - container_id = container.id - logger.debug("Filtering to items in project {}".format(container.id)) - result.extend(list(filter(lambda item: item.project_id == container_id, all_items))) - else: - result.extend(all_items) - - if total_retrieved_items >= total_available_items: + result.extend(all_items) + if total_retrieved_items >= pagination_item.total_available: break page_number = pagination_item.page_number + 1 diff --git a/tests/commands/test_server.py b/tests/commands/test_server.py index b8a7d293..6f8ed058 100644 --- a/tests/commands/test_server.py +++ b/tests/commands/test_server.py @@ -1,3 +1,4 @@ +import pytest import unittest from unittest.mock import MagicMock, patch import tableauserverclient as TSC @@ -72,29 +73,6 @@ def test_get_items_by_name_with_container(self, MockFilter, MockRequestOptions): logger.debug.assert_called() item_endpoint.get.assert_called() - @patch("tabcmd.commands.server.TSC.RequestOptions") - @patch("tabcmd.commands.server.TSC.Filter") - def test_get_items_by_name_with_container_no_match(self, MockFilter, MockRequestOptions): - logger = MagicMock() - item_endpoint = MagicMock() - item_name = "test_item" - container = MagicMock() - container.id = "container_id" - - pagination_item = MagicMock() - pagination_item.total_available = 1 - pagination_item.page_number = 1 - pagination_item.page_size = 1 - - item = MagicMock() - item.project_id = "different_container_id" - item_endpoint.get.return_value = ([item], pagination_item) - - result = Server.get_items_by_name(logger, item_endpoint, item_name, container) - - self.assertEqual(result, []) - logger.debug.assert_called() - item_endpoint.get.assert_called() @patch("tabcmd.commands.server.TSC.RequestOptions") @patch("tabcmd.commands.server.TSC.Filter") @@ -134,89 +112,3 @@ def test_get_items_by_name_multiple_pages(self, MockFilter, MockRequestOptions): self.assertEqual(result, [item_1, item_2, item_3]) self.assertEqual(item_endpoint.get.call_count, 3) logger.debug.assert_called() - - @patch("tabcmd.commands.server.TSC.RequestOptions") - @patch("tabcmd.commands.server.TSC.Filter") - def test_get_items_by_name_multiple_pages_with_container(self, MockFilter, MockRequestOptions): - logger = MagicMock() - item_endpoint = MagicMock() - item_name = "test_item" - container = MagicMock() - container.id = "container_id" - - pagination_item_1 = MagicMock() - pagination_item_1.total_available = 3 - pagination_item_1.page_number = 1 - pagination_item_1.page_size = 1 - - pagination_item_2 = MagicMock() - pagination_item_2.total_available = 3 - pagination_item_2.page_number = 2 - pagination_item_2.page_size = 1 - - pagination_item_3 = MagicMock() - pagination_item_3.total_available = 3 - pagination_item_3.page_number = 3 - pagination_item_3.page_size = 1 - - item_1 = MagicMock() - item_1.project_id = "container_id_1" - item_2 = MagicMock() - item_2.project_id = "container_id" - item_3 = MagicMock() - item_3.project_id = "container_id_2" - - item_endpoint.get.side_effect = [ - ([item_1], pagination_item_1), - ([item_2], pagination_item_2), - ([item_3], pagination_item_3), - ] - - result = Server.get_items_by_name(logger, item_endpoint, item_name, container) - - self.assertEqual(result, [item_2]) - self.assertEqual(item_endpoint.get.call_count, 3) - logger.debug.assert_called() - - @patch("tabcmd.commands.server.TSC.RequestOptions") - @patch("tabcmd.commands.server.TSC.Filter") - def test_get_items_by_name_multiple_pages_no_container_match(self, MockFilter, MockRequestOptions): - logger = MagicMock() - item_endpoint = MagicMock() - item_name = "test_item" - container = MagicMock() - container.id = "container_id" - - pagination_item_1 = MagicMock() - pagination_item_1.total_available = 3 - pagination_item_1.page_number = 1 - pagination_item_1.page_size = 1 - - pagination_item_2 = MagicMock() - pagination_item_2.total_available = 3 - pagination_item_2.page_number = 2 - pagination_item_2.page_size = 1 - - pagination_item_3 = MagicMock() - pagination_item_3.total_available = 3 - pagination_item_3.page_number = 3 - pagination_item_3.page_size = 1 - - item_1 = MagicMock() - item_1.project_id = "different_container_id_1" - item_2 = MagicMock() - item_2.project_id = "different_container_id_2" - item_3 = MagicMock() - item_3.project_id = "different_container_id_3" - - item_endpoint.get.side_effect = [ - ([item_1], pagination_item_1), - ([item_2], pagination_item_2), - ([item_3], pagination_item_3), - ] - - result = Server.get_items_by_name(logger, item_endpoint, item_name, container) - - self.assertEqual(result, []) - self.assertEqual(item_endpoint.get.call_count, 3) - logger.debug.assert_called() From ce4048ee0bb3bbf596e6fd76fb72bfd5410b6e20 Mon Sep 17 00:00:00 2001 From: Jac Fitzgerald Date: Sat, 21 Dec 2024 01:38:38 -0800 Subject: [PATCH 2/2] fix query filter for projects --- tabcmd/commands/server.py | 26 +++++++++++++++----------- tests/commands/test_server.py | 1 - 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tabcmd/commands/server.py b/tabcmd/commands/server.py index 018e9737..19136141 100644 --- a/tabcmd/commands/server.py +++ b/tabcmd/commands/server.py @@ -61,12 +61,21 @@ def get_items_by_name(logger, item_endpoint, item_name: str, container: Optional req_option.filter.add( TSC.Filter(TSC.RequestOptions.Field.Name, TSC.RequestOptions.Operator.Equals, item_name) ) + + # todo - this doesn't filter if the project is in the top level. + # todo: there is no guarantee that these fields are the same for different content types. + # probably better if we move that type specific logic out to a wrapper if container: - req_option.filter.add( - TSC.Filter( - TSC.RequestOptions.Field.ParentProjectId, TSC.RequestOptions.Operator.Equals, container.id - ) - ) + # the name of the filter field is different if you are finding a project or any other item + if type(item_endpoint).__name__.find("Projects") < 0: + parentField = TSC.RequestOptions.Field.ProjectName + parentValue = container.name + else: + parentField = TSC.RequestOptions.Field.ParentProjectId + parentValue = container.id + logger.debug("filtering for parent with {}".format(parentField)) + + req_option.filter.add(TSC.Filter(parentField, TSC.RequestOptions.Operator.Equals, parentValue)) all_items, pagination_item = item_endpoint.get(req_option) @@ -165,12 +174,7 @@ def _parse_project_path_to_list(project_path: str): def _get_project_by_name_and_parent(logger, server, project_name: str, parent: Optional[TSC.ProjectItem]): # logger.debug("get by name and parent: {0}, {1}".format(project_name, parent)) # get by name to narrow down the list - projects = Server.get_items_by_name(logger, server.projects, project_name) - if parent is not None: - parent_id = parent.id - for project in projects: - if project.parent_id == parent_id: - return project + projects = Server.get_items_by_name(logger, server.projects, project_name, parent) return projects[0] @staticmethod diff --git a/tests/commands/test_server.py b/tests/commands/test_server.py index 6f8ed058..aaf287e5 100644 --- a/tests/commands/test_server.py +++ b/tests/commands/test_server.py @@ -73,7 +73,6 @@ def test_get_items_by_name_with_container(self, MockFilter, MockRequestOptions): logger.debug.assert_called() item_endpoint.get.assert_called() - @patch("tabcmd.commands.server.TSC.RequestOptions") @patch("tabcmd.commands.server.TSC.Filter") def test_get_items_by_name_multiple_pages(self, MockFilter, MockRequestOptions):