diff --git a/tabcmd/commands/server.py b/tabcmd/commands/server.py index 46e77d01..19136141 100644 --- a/tabcmd/commands/server.py +++ b/tabcmd/commands/server.py @@ -61,18 +61,31 @@ def get_items_by_name(logger, item_endpoint, item_name: str, container: Optional req_option.filter.add( TSC.Filter(TSC.RequestOptions.Field.Name, TSC.RequestOptions.Operator.Equals, item_name) ) + + # todo - this doesn't filter if the project is in the top level. + # todo: there is no guarantee that these fields are the same for different content types. + # probably better if we move that type specific logic out to a wrapper + if container: + # the name of the filter field is different if you are finding a project or any other item + if type(item_endpoint).__name__.find("Projects") < 0: + parentField = TSC.RequestOptions.Field.ProjectName + parentValue = container.name + else: + parentField = TSC.RequestOptions.Field.ParentProjectId + parentValue = container.id + logger.debug("filtering for parent with {}".format(parentField)) + + req_option.filter.add(TSC.Filter(parentField, TSC.RequestOptions.Operator.Equals, parentValue)) + all_items, pagination_item = item_endpoint.get(req_option) - if all_items is None or all_items == []: + if pagination_item.total_available == 0: raise TSC.ServerResponseError( code="404", summary=_("errors.xmlapi.not_found"), detail=_("errors.xmlapi.not_found") + ": " + item_log_name, ) - if total_available_items is None: - total_available_items = pagination_item.total_available - total_retrieved_items += len(all_items) logger.debug( @@ -85,14 +98,8 @@ def get_items_by_name(logger, item_endpoint, item_name: str, container: Optional ) ) - if container: - container_id = container.id - logger.debug("Filtering to items in project {}".format(container.id)) - result.extend(list(filter(lambda item: item.project_id == container_id, all_items))) - else: - result.extend(all_items) - - if total_retrieved_items >= total_available_items: + result.extend(all_items) + if total_retrieved_items >= pagination_item.total_available: break page_number = pagination_item.page_number + 1 @@ -167,12 +174,7 @@ def _parse_project_path_to_list(project_path: str): def _get_project_by_name_and_parent(logger, server, project_name: str, parent: Optional[TSC.ProjectItem]): # logger.debug("get by name and parent: {0}, {1}".format(project_name, parent)) # get by name to narrow down the list - projects = Server.get_items_by_name(logger, server.projects, project_name) - if parent is not None: - parent_id = parent.id - for project in projects: - if project.parent_id == parent_id: - return project + projects = Server.get_items_by_name(logger, server.projects, project_name, parent) return projects[0] @staticmethod diff --git a/tests/commands/test_server.py b/tests/commands/test_server.py index b8a7d293..aaf287e5 100644 --- a/tests/commands/test_server.py +++ b/tests/commands/test_server.py @@ -1,3 +1,4 @@ +import pytest import unittest from unittest.mock import MagicMock, patch import tableauserverclient as TSC @@ -72,30 +73,6 @@ def test_get_items_by_name_with_container(self, MockFilter, MockRequestOptions): logger.debug.assert_called() item_endpoint.get.assert_called() - @patch("tabcmd.commands.server.TSC.RequestOptions") - @patch("tabcmd.commands.server.TSC.Filter") - def test_get_items_by_name_with_container_no_match(self, MockFilter, MockRequestOptions): - logger = MagicMock() - item_endpoint = MagicMock() - item_name = "test_item" - container = MagicMock() - container.id = "container_id" - - pagination_item = MagicMock() - pagination_item.total_available = 1 - pagination_item.page_number = 1 - pagination_item.page_size = 1 - - item = MagicMock() - item.project_id = "different_container_id" - item_endpoint.get.return_value = ([item], pagination_item) - - result = Server.get_items_by_name(logger, item_endpoint, item_name, container) - - self.assertEqual(result, []) - logger.debug.assert_called() - item_endpoint.get.assert_called() - @patch("tabcmd.commands.server.TSC.RequestOptions") @patch("tabcmd.commands.server.TSC.Filter") def test_get_items_by_name_multiple_pages(self, MockFilter, MockRequestOptions): @@ -134,89 +111,3 @@ def test_get_items_by_name_multiple_pages(self, MockFilter, MockRequestOptions): self.assertEqual(result, [item_1, item_2, item_3]) self.assertEqual(item_endpoint.get.call_count, 3) logger.debug.assert_called() - - @patch("tabcmd.commands.server.TSC.RequestOptions") - @patch("tabcmd.commands.server.TSC.Filter") - def test_get_items_by_name_multiple_pages_with_container(self, MockFilter, MockRequestOptions): - logger = MagicMock() - item_endpoint = MagicMock() - item_name = "test_item" - container = MagicMock() - container.id = "container_id" - - pagination_item_1 = MagicMock() - pagination_item_1.total_available = 3 - pagination_item_1.page_number = 1 - pagination_item_1.page_size = 1 - - pagination_item_2 = MagicMock() - pagination_item_2.total_available = 3 - pagination_item_2.page_number = 2 - pagination_item_2.page_size = 1 - - pagination_item_3 = MagicMock() - pagination_item_3.total_available = 3 - pagination_item_3.page_number = 3 - pagination_item_3.page_size = 1 - - item_1 = MagicMock() - item_1.project_id = "container_id_1" - item_2 = MagicMock() - item_2.project_id = "container_id" - item_3 = MagicMock() - item_3.project_id = "container_id_2" - - item_endpoint.get.side_effect = [ - ([item_1], pagination_item_1), - ([item_2], pagination_item_2), - ([item_3], pagination_item_3), - ] - - result = Server.get_items_by_name(logger, item_endpoint, item_name, container) - - self.assertEqual(result, [item_2]) - self.assertEqual(item_endpoint.get.call_count, 3) - logger.debug.assert_called() - - @patch("tabcmd.commands.server.TSC.RequestOptions") - @patch("tabcmd.commands.server.TSC.Filter") - def test_get_items_by_name_multiple_pages_no_container_match(self, MockFilter, MockRequestOptions): - logger = MagicMock() - item_endpoint = MagicMock() - item_name = "test_item" - container = MagicMock() - container.id = "container_id" - - pagination_item_1 = MagicMock() - pagination_item_1.total_available = 3 - pagination_item_1.page_number = 1 - pagination_item_1.page_size = 1 - - pagination_item_2 = MagicMock() - pagination_item_2.total_available = 3 - pagination_item_2.page_number = 2 - pagination_item_2.page_size = 1 - - pagination_item_3 = MagicMock() - pagination_item_3.total_available = 3 - pagination_item_3.page_number = 3 - pagination_item_3.page_size = 1 - - item_1 = MagicMock() - item_1.project_id = "different_container_id_1" - item_2 = MagicMock() - item_2.project_id = "different_container_id_2" - item_3 = MagicMock() - item_3.project_id = "different_container_id_3" - - item_endpoint.get.side_effect = [ - ([item_1], pagination_item_1), - ([item_2], pagination_item_2), - ([item_3], pagination_item_3), - ] - - result = Server.get_items_by_name(logger, item_endpoint, item_name, container) - - self.assertEqual(result, []) - self.assertEqual(item_endpoint.get.call_count, 3) - logger.debug.assert_called()