diff --git a/deploy-service/common/src/main/java/com/pinterest/deployservice/handler/EnvironHandler.java b/deploy-service/common/src/main/java/com/pinterest/deployservice/handler/EnvironHandler.java index ce781dd6ff..e6c7cbab88 100644 --- a/deploy-service/common/src/main/java/com/pinterest/deployservice/handler/EnvironHandler.java +++ b/deploy-service/common/src/main/java/com/pinterest/deployservice/handler/EnvironHandler.java @@ -444,9 +444,8 @@ public void updateHosts(EnvironBean envBean, List hosts, String operator public void updateGroups(EnvironBean envBean, List groups, String operator) throws Exception { // TODO need to check group env conflicts and reject if so - List oldGroupList = groupDAO.getCapacityGroups(envBean.getEnv_id()); Set oldGroups = new HashSet<>(); - oldGroups.addAll(oldGroupList); + oldGroups.addAll(groupDAO.getCapacityGroups(envBean.getEnv_id())); for (String group : groups) { if (!oldGroups.contains(group)) { groupDAO.addGroupCapacity(envBean.getEnv_id(), group); @@ -455,7 +454,7 @@ public void updateGroups(EnvironBean envBean, List groups, String operat } } for (String group : oldGroups) { - if (group == envBean.getCluster_name()) { + if (group.equals(envBean.getCluster_name())) { LOG.info("Skipping implicit group {}", group); continue; } diff --git a/deploy-service/common/src/test/java/com/pinterest/deployservice/handler/EnvironHandlerTest.java b/deploy-service/common/src/test/java/com/pinterest/deployservice/handler/EnvironHandlerTest.java index c7cc557e7d..1fb10c9275 100644 --- a/deploy-service/common/src/test/java/com/pinterest/deployservice/handler/EnvironHandlerTest.java +++ b/deploy-service/common/src/test/java/com/pinterest/deployservice/handler/EnvironHandlerTest.java @@ -20,10 +20,12 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.pinterest.deployservice.ServiceContext; import com.pinterest.deployservice.bean.EnvType; import com.pinterest.deployservice.bean.EnvironBean; @@ -31,6 +33,7 @@ import com.pinterest.deployservice.bean.HostState; import com.pinterest.deployservice.dao.AgentDAO; import com.pinterest.deployservice.dao.EnvironDAO; +import com.pinterest.deployservice.dao.GroupDAO; import com.pinterest.deployservice.dao.HostDAO; import java.sql.SQLException; import java.util.Arrays; @@ -41,31 +44,38 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; class EnvironHandlerTest { + private static final String TEST_OPERATOR = "operator"; + private static final String TEST_CLUSTER_NAME = "clusterName"; private static final String DEFAULT_HOST_ID = "hostId"; private EnvironHandler environHandler; - private HostDAO mockHostDAO; - private AgentDAO mockAgentDAO; - private EnvironDAO environDAO; + @Mock private HostDAO mockHostDAO; + @Mock private AgentDAO mockAgentDAO; + @Mock private EnvironDAO environDAO; + @Mock private GroupDAO groupDAO; + EnvironBean testEnvBean; private List hostIds = Arrays.asList("hostId1", "hostId2"); - private ServiceContext createMockServiceContext() throws Exception { - mockHostDAO = mock(HostDAO.class); - mockAgentDAO = mock(AgentDAO.class); - environDAO = mock(EnvironDAO.class); - + private ServiceContext createMockServiceContext() { ServiceContext serviceContext = new ServiceContext(); serviceContext.setHostDAO(mockHostDAO); serviceContext.setAgentDAO(mockAgentDAO); serviceContext.setEnvironDAO(environDAO); + serviceContext.setGroupDAO(groupDAO); return serviceContext; } @BeforeEach - void setUp() throws Exception { + void setUp() { + MockitoAnnotations.initMocks(this); + testEnvBean = new EnvironBean(); + testEnvBean.setEnv_id("envId"); + testEnvBean.setCluster_name(TEST_CLUSTER_NAME); environHandler = new EnvironHandler(createMockServiceContext()); } @@ -91,15 +101,14 @@ void stopServiceOnHost_withoutReplaceHost_hostBeanStateIsPendingTerminateNoRepla @Test void updateStage_type_enables_private_build() throws Exception { ArgumentCaptor argument = ArgumentCaptor.forClass(EnvironBean.class); - EnvironBean envBean = new EnvironBean(); - envBean.setStage_type(EnvType.DEV); - environHandler.createEnvStage(envBean, "Anonymous"); + testEnvBean.setStage_type(EnvType.DEV); + environHandler.createEnvStage(testEnvBean, "Anonymous"); verify(environDAO).insert(argument.capture()); assertEquals(true, argument.getValue().getAllow_private_build()); } @Test - void ensureHostsOwnedByEnv_noMainEnv() throws Exception { + void ensureHostsOwnedByEnv_noMainEnv() { assertThrows( NotFoundException.class, () -> environHandler.ensureHostsOwnedByEnv(new EnvironBean(), hostIds)); @@ -107,9 +116,7 @@ void ensureHostsOwnedByEnv_noMainEnv() throws Exception { @Test void ensureHostsOwnedByEnv_differentMainEnv() throws Exception { - EnvironBean envBean = new EnvironBean(); - envBean.setEnv_id("envId"); - when(environDAO.getMainEnvByHostId(anyString())).thenReturn(envBean); + when(environDAO.getMainEnvByHostId(anyString())).thenReturn(testEnvBean); assertThrows( ForbiddenException.class, () -> environHandler.ensureHostsOwnedByEnv(new EnvironBean(), hostIds)); @@ -117,10 +124,8 @@ void ensureHostsOwnedByEnv_differentMainEnv() throws Exception { @Test void ensureHostsOwnedByEnv_sameMainEnv() throws Exception { - EnvironBean envBean = new EnvironBean(); - envBean.setEnv_id("envId"); - when(environDAO.getMainEnvByHostId(anyString())).thenReturn(envBean); - assertDoesNotThrow(() -> environHandler.ensureHostsOwnedByEnv(envBean, hostIds)); + when(environDAO.getMainEnvByHostId(anyString())).thenReturn(testEnvBean); + assertDoesNotThrow(() -> environHandler.ensureHostsOwnedByEnv(testEnvBean, hostIds)); } @Test @@ -130,4 +135,56 @@ void ensureHostsOwnedByEnv_sqlException() throws Exception { WebApplicationException.class, () -> environHandler.ensureHostsOwnedByEnv(new EnvironBean(), hostIds)); } + + @Test + void updateGroups_addNewGroups() throws Exception { + List groups = ImmutableList.of("group1", "group2"); + when(groupDAO.getCapacityGroups(testEnvBean.getEnv_id())) + .thenReturn(ImmutableList.of(TEST_CLUSTER_NAME)); + + environHandler.updateGroups(testEnvBean, groups, TEST_OPERATOR); + + ArgumentCaptor groupCaptor = ArgumentCaptor.forClass(String.class); + + verify(groupDAO, times(2)).addGroupCapacity(anyString(), groupCaptor.capture()); + verify(groupDAO, never()).removeGroupCapacity(anyString(), anyString()); + + List capturedGroups = groupCaptor.getAllValues(); + + for (int i = 0; i < groups.size(); i++) { + assertEquals(groups.get(i), capturedGroups.get(i)); + } + } + + @Test + void updateGroups_addEmptyGroups() throws Exception { + List groups = ImmutableList.of(); + when(groupDAO.getCapacityGroups(testEnvBean.getEnv_id())) + .thenReturn(ImmutableList.of(new String(TEST_CLUSTER_NAME))); + + environHandler.updateGroups(testEnvBean, groups, TEST_OPERATOR); + + verify(groupDAO, never()).addGroupCapacity(anyString(), anyString()); + verify(groupDAO, never()).removeGroupCapacity(anyString(), anyString()); + } + + @Test + void updateGroups_replaceGroups() throws Exception { + ImmutableList oldGroups = ImmutableList.of(TEST_CLUSTER_NAME, "group1"); + when(groupDAO.getCapacityGroups(testEnvBean.getEnv_id())).thenReturn(oldGroups); + + environHandler.updateGroups(testEnvBean, ImmutableList.of("group2"), TEST_OPERATOR); + + ArgumentCaptor removedGroupCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor addedGroupCaptor = ArgumentCaptor.forClass(String.class); + + verify(groupDAO, times(1)).addGroupCapacity(anyString(), addedGroupCaptor.capture()); + verify(groupDAO, times(1)).removeGroupCapacity(anyString(), removedGroupCaptor.capture()); + + List addedGroups = addedGroupCaptor.getAllValues(); + assertEquals("group2", addedGroups.get(0)); + + List removedGroups = removedGroupCaptor.getAllValues(); + assertEquals("group1", removedGroups.get(0)); + } }