Skip to content

Commit

Permalink
Fix a group capacity update bug (#1698)
Browse files Browse the repository at this point in the history
equals() instead of ==
  • Loading branch information
tylerwowen committed Aug 16, 2024
1 parent 5de012e commit 0e62d1e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,8 @@ public void updateHosts(EnvironBean envBean, List<String> hosts, String operator
public void updateGroups(EnvironBean envBean, List<String> groups, String operator)
throws Exception {
// TODO need to check group env conflicts and reject if so
List<String> oldGroupList = groupDAO.getCapacityGroups(envBean.getEnv_id());
Set<String> oldGroups = new HashSet<>();
oldGroups.addAll(oldGroupList);
oldGroups.addAll(groupDAO.getCapacityGroups(envBean.getEnv_id()));
for (String group : groups) {
if (!oldGroups.contains(group)) {
groupDAO.addGroupCapacity(envBean.getEnv_id(), group);
Expand All @@ -455,7 +454,7 @@ public void updateGroups(EnvironBean envBean, List<String> groups, String operat
}
}
for (String group : oldGroups) {
if (group == envBean.getCluster_name()) {
if (group.equals(envBean.getCluster_name())) {
LOG.info("Skipping implicit group {}", group);
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.pinterest.deployservice.ServiceContext;
import com.pinterest.deployservice.bean.EnvType;
import com.pinterest.deployservice.bean.EnvironBean;
import com.pinterest.deployservice.bean.HostBean;
import com.pinterest.deployservice.bean.HostState;
import com.pinterest.deployservice.dao.AgentDAO;
import com.pinterest.deployservice.dao.EnvironDAO;
import com.pinterest.deployservice.dao.GroupDAO;
import com.pinterest.deployservice.dao.HostDAO;
import java.sql.SQLException;
import java.util.Arrays;
Expand All @@ -41,31 +44,38 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

class EnvironHandlerTest {
private static final String TEST_OPERATOR = "operator";
private static final String TEST_CLUSTER_NAME = "clusterName";
private static final String DEFAULT_HOST_ID = "hostId";

private EnvironHandler environHandler;

private HostDAO mockHostDAO;
private AgentDAO mockAgentDAO;
private EnvironDAO environDAO;
@Mock private HostDAO mockHostDAO;
@Mock private AgentDAO mockAgentDAO;
@Mock private EnvironDAO environDAO;
@Mock private GroupDAO groupDAO;
EnvironBean testEnvBean;
private List<String> hostIds = Arrays.asList("hostId1", "hostId2");

private ServiceContext createMockServiceContext() throws Exception {
mockHostDAO = mock(HostDAO.class);
mockAgentDAO = mock(AgentDAO.class);
environDAO = mock(EnvironDAO.class);

private ServiceContext createMockServiceContext() {
ServiceContext serviceContext = new ServiceContext();
serviceContext.setHostDAO(mockHostDAO);
serviceContext.setAgentDAO(mockAgentDAO);
serviceContext.setEnvironDAO(environDAO);
serviceContext.setGroupDAO(groupDAO);
return serviceContext;
}

@BeforeEach
void setUp() throws Exception {
void setUp() {
MockitoAnnotations.initMocks(this);
testEnvBean = new EnvironBean();
testEnvBean.setEnv_id("envId");
testEnvBean.setCluster_name(TEST_CLUSTER_NAME);
environHandler = new EnvironHandler(createMockServiceContext());
}

Expand All @@ -91,36 +101,31 @@ void stopServiceOnHost_withoutReplaceHost_hostBeanStateIsPendingTerminateNoRepla
@Test
void updateStage_type_enables_private_build() throws Exception {
ArgumentCaptor<EnvironBean> argument = ArgumentCaptor.forClass(EnvironBean.class);
EnvironBean envBean = new EnvironBean();
envBean.setStage_type(EnvType.DEV);
environHandler.createEnvStage(envBean, "Anonymous");
testEnvBean.setStage_type(EnvType.DEV);
environHandler.createEnvStage(testEnvBean, "Anonymous");
verify(environDAO).insert(argument.capture());
assertEquals(true, argument.getValue().getAllow_private_build());
}

@Test
void ensureHostsOwnedByEnv_noMainEnv() throws Exception {
void ensureHostsOwnedByEnv_noMainEnv() {
assertThrows(
NotFoundException.class,
() -> environHandler.ensureHostsOwnedByEnv(new EnvironBean(), hostIds));
}

@Test
void ensureHostsOwnedByEnv_differentMainEnv() throws Exception {
EnvironBean envBean = new EnvironBean();
envBean.setEnv_id("envId");
when(environDAO.getMainEnvByHostId(anyString())).thenReturn(envBean);
when(environDAO.getMainEnvByHostId(anyString())).thenReturn(testEnvBean);
assertThrows(
ForbiddenException.class,
() -> environHandler.ensureHostsOwnedByEnv(new EnvironBean(), hostIds));
}

@Test
void ensureHostsOwnedByEnv_sameMainEnv() throws Exception {
EnvironBean envBean = new EnvironBean();
envBean.setEnv_id("envId");
when(environDAO.getMainEnvByHostId(anyString())).thenReturn(envBean);
assertDoesNotThrow(() -> environHandler.ensureHostsOwnedByEnv(envBean, hostIds));
when(environDAO.getMainEnvByHostId(anyString())).thenReturn(testEnvBean);
assertDoesNotThrow(() -> environHandler.ensureHostsOwnedByEnv(testEnvBean, hostIds));
}

@Test
Expand All @@ -130,4 +135,56 @@ void ensureHostsOwnedByEnv_sqlException() throws Exception {
WebApplicationException.class,
() -> environHandler.ensureHostsOwnedByEnv(new EnvironBean(), hostIds));
}

@Test
void updateGroups_addNewGroups() throws Exception {
List<String> groups = ImmutableList.of("group1", "group2");
when(groupDAO.getCapacityGroups(testEnvBean.getEnv_id()))
.thenReturn(ImmutableList.of(TEST_CLUSTER_NAME));

environHandler.updateGroups(testEnvBean, groups, TEST_OPERATOR);

ArgumentCaptor<String> groupCaptor = ArgumentCaptor.forClass(String.class);

verify(groupDAO, times(2)).addGroupCapacity(anyString(), groupCaptor.capture());
verify(groupDAO, never()).removeGroupCapacity(anyString(), anyString());

List<String> capturedGroups = groupCaptor.getAllValues();

for (int i = 0; i < groups.size(); i++) {
assertEquals(groups.get(i), capturedGroups.get(i));
}
}

@Test
void updateGroups_addEmptyGroups() throws Exception {
List<String> groups = ImmutableList.of();
when(groupDAO.getCapacityGroups(testEnvBean.getEnv_id()))
.thenReturn(ImmutableList.of(new String(TEST_CLUSTER_NAME)));

environHandler.updateGroups(testEnvBean, groups, TEST_OPERATOR);

verify(groupDAO, never()).addGroupCapacity(anyString(), anyString());
verify(groupDAO, never()).removeGroupCapacity(anyString(), anyString());
}

@Test
void updateGroups_replaceGroups() throws Exception {
ImmutableList<String> oldGroups = ImmutableList.of(TEST_CLUSTER_NAME, "group1");
when(groupDAO.getCapacityGroups(testEnvBean.getEnv_id())).thenReturn(oldGroups);

environHandler.updateGroups(testEnvBean, ImmutableList.of("group2"), TEST_OPERATOR);

ArgumentCaptor<String> removedGroupCaptor = ArgumentCaptor.forClass(String.class);
ArgumentCaptor<String> addedGroupCaptor = ArgumentCaptor.forClass(String.class);

verify(groupDAO, times(1)).addGroupCapacity(anyString(), addedGroupCaptor.capture());
verify(groupDAO, times(1)).removeGroupCapacity(anyString(), removedGroupCaptor.capture());

List<String> addedGroups = addedGroupCaptor.getAllValues();
assertEquals("group2", addedGroups.get(0));

List<String> removedGroups = removedGroupCaptor.getAllValues();
assertEquals("group1", removedGroups.get(0));
}
}

0 comments on commit 0e62d1e

Please sign in to comment.