Skip to content

Commit

Permalink
Revert soft failing change to throw UnauthorizedTaggingOperation erro…
Browse files Browse the repository at this point in the history
…r code (#467)
  • Loading branch information
moataz-mhmd authored Oct 2, 2023
1 parent 71ea68c commit dc074b9
Show file tree
Hide file tree
Showing 40 changed files with 588 additions and 262 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,34 @@
import software.amazon.awssdk.services.rds.model.Tag;
import software.amazon.cloudformation.proxy.AmazonWebServicesClientProxy;
import software.amazon.cloudformation.proxy.HandlerErrorCode;
import software.amazon.cloudformation.proxy.OperationStatus;
import software.amazon.cloudformation.proxy.ProgressEvent;
import software.amazon.cloudformation.proxy.ProxyClient;
import software.amazon.rds.common.error.ErrorCode;
import software.amazon.rds.common.error.ErrorRuleSet;
import software.amazon.rds.common.error.ErrorStatus;

public final class Tagging {
public static final ErrorRuleSet IGNORE_LIST_TAGS_PERMISSION_DENIED_ERROR_RULE_SET = ErrorRuleSet
public static final ErrorRuleSet SOFT_FAIL_IN_PROGRESS_TAGGING_ERROR_RULE_SET = ErrorRuleSet
.extend(ErrorRuleSet.EMPTY_RULE_SET)
.withErrorCodes(ErrorStatus.ignore(),
.withErrorCodes(ErrorStatus.ignore(OperationStatus.IN_PROGRESS),
ErrorCode.AccessDenied,
ErrorCode.AccessDeniedException)
.build();

public static final ErrorRuleSet STACK_TAGS_ERROR_RULE_SET = ErrorRuleSet
public static final ErrorRuleSet SOFT_FAIL_TAG_ERROR_RULE_SET = ErrorRuleSet
.extend(ErrorRuleSet.EMPTY_RULE_SET)
.withErrorCodes(ErrorStatus.failWith(HandlerErrorCode.UnauthorizedTaggingOperation),
.withErrorCodes(ErrorStatus.ignore(),
ErrorCode.AccessDenied,
ErrorCode.AccessDeniedException
).build();

public static final ErrorRuleSet RESOURCE_TAG_ERROR_RULE_SET = ErrorRuleSet
public static final ErrorRuleSet HARD_FAIL_TAG_ERROR_RULE_SET = ErrorRuleSet
.extend(ErrorRuleSet.EMPTY_RULE_SET)
.withErrorCodes(ErrorStatus.failWith(HandlerErrorCode.AccessDenied),
ErrorCode.AccessDenied,
ErrorCode.AccessDeniedException
).build();
public static final String RDS_ADD_TAGS_TO_RESOURCE_ACTION = "rds:AddTagsToResource";

public static TagSet exclude(final TagSet from, final TagSet what) {
final Set<Tag> systemTags = new LinkedHashSet<>(from.getSystemTags());
Expand Down Expand Up @@ -184,52 +184,56 @@ private static RemoveTagsFromResourceRequest removeTagsFromResourceRequest(
.build();
}

public static ErrorRuleSet getUpdateTagsAccessDeniedRuleSet(
public static ErrorRuleSet bestEffortErrorRuleSet(
final TagSet tagsToAdd,
final TagSet tagsToRemove
) {
return getUpdateTagsAccessDeniedRuleSet(tagsToAdd, tagsToRemove, STACK_TAGS_ERROR_RULE_SET, RESOURCE_TAG_ERROR_RULE_SET);
return bestEffortErrorRuleSet(tagsToAdd, tagsToRemove, SOFT_FAIL_TAG_ERROR_RULE_SET, HARD_FAIL_TAG_ERROR_RULE_SET);
}

public static ErrorRuleSet getUpdateTagsAccessDeniedRuleSet(
public static ErrorRuleSet bestEffortErrorRuleSet(
final TagSet tagsToAdd,
final TagSet tagsToRemove,
final ErrorRuleSet stackTagsErrorRuleSet,
final ErrorRuleSet resourceTagsErrorRuleSet
final ErrorRuleSet softFailErrorRuleSet,
final ErrorRuleSet hardFailErrorRuleSet
) {
/* If the tagging operation comes across an AccessDenied error, we will throw an UnauthorizedTaggingOperation
errorCode for stack level tags. For Resource tags, if they are included, we will throw an AccessDenied error.
This is done to ensure backward compatibility. */
// Only soft fail if the customer provided no resource-level tags
if (tagsToAdd.getResourceTags().isEmpty() && tagsToRemove.getResourceTags().isEmpty()) {
return stackTagsErrorRuleSet;
return softFailErrorRuleSet;
}
return resourceTagsErrorRuleSet;
return hardFailErrorRuleSet;
}

public static <M, C extends TaggingContext.Provider> ProgressEvent<M, C> createWithTaggingFallback(
public static <M, C extends TaggingContext.Provider> ProgressEvent<M, C> safeCreate(
final AmazonWebServicesClientProxy proxy,
final ProxyClient<RdsClient> rdsProxyClient,
final HandlerMethod<M, C> handlerMethod,
final ProgressEvent<M, C> progress,
final Tagging.TagSet allTags
) {
final ProgressEvent<M, C> allTagsResult = handlerMethod.invoke(proxy, rdsProxyClient, progress, allTags);
if (allTagsResult.isFailed()) {
if (isUnauthorizedTaggingFailure(allTagsResult, allTags)) {
allTagsResult.setErrorCode(HandlerErrorCode.UnauthorizedTaggingOperation);
return progress.then(p -> {
final C context = p.getCallbackContext();
if (context.getTaggingContext().isSoftFailTags()) {
return p;
}
final ProgressEvent<M, C> allTagsResult = handlerMethod.invoke(proxy, rdsProxyClient, p, allTags);
if (allTagsResult.isFailed()) {
if (HandlerErrorCode.AccessDenied.equals(allTagsResult.getErrorCode())) {
context.getTaggingContext().setSoftFailTags(true);
return ProgressEvent.progress(allTagsResult.getResourceModel(), context);
}
return allTagsResult;
}
allTagsResult.getCallbackContext().getTaggingContext().setAddTagsComplete(true);
return allTagsResult;
}
allTagsResult.getCallbackContext().getTaggingContext().setAddTagsComplete(true);
return allTagsResult;
}

private static <M, C extends TaggingContext.Provider> boolean isUnauthorizedTaggingFailure(final ProgressEvent<M, C> allTagsResult,
final TagSet allTags) {
return HandlerErrorCode.AccessDenied.equals(allTagsResult.getErrorCode()) &&
allTags.getResourceTags().isEmpty() &&
allTagsResult.getMessage() != null &&
allTagsResult.getMessage().contains(RDS_ADD_TAGS_TO_RESOURCE_ACTION);
}).then(p -> {
final C context = p.getCallbackContext();
if (!context.getTaggingContext().isSoftFailTags()) {
return p;
}
final Tagging.TagSet systemTags = Tagging.TagSet.builder().systemTags(allTags.getSystemTags()).build();
return handlerMethod.invoke(proxy, rdsProxyClient, p, systemTags);
});
}

private static void addToMapIfAbsent(Map<String, Tag> allTags, Collection<Tag> tags) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@lombok.ToString
@lombok.EqualsAndHashCode
public class TaggingContext {
private boolean softFailTags;
private boolean addTagsComplete;

public interface Provider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;

Expand All @@ -39,14 +40,15 @@
import software.amazon.cloudformation.proxy.ResourceHandlerRequest;
import software.amazon.rds.common.error.ErrorRuleSet;
import software.amazon.rds.common.error.ErrorStatus;
import software.amazon.rds.common.error.IgnoreErrorStatus;
import software.amazon.rds.common.error.UnexpectedErrorStatus;
import software.amazon.rds.common.logging.LoggingProxyClient;
import software.amazon.rds.common.logging.RequestLogger;
import software.amazon.rds.common.printer.FilteredJsonPrinter;

public class TaggingTest extends ProxyClientTestBase {

private final static String FAILED_MSG = "User is not authorized to do rds:AddTagsToResource action";
private final static String FAILED_MSG = "Test message: failed";

final Set<Tag> SYSTEM_TAGS = Stream.of(
Tag.builder().key("system-tag-key-1").value("system-tag-value-1").build(),
Expand Down Expand Up @@ -143,9 +145,9 @@ void test_SoftFailErrorRuleSet_AwsServiceException_AccessDenied() {
.errorCode("AccessDenied")
.build())
.build();
final ErrorRuleSet ruleSet = Tagging.STACK_TAGS_ERROR_RULE_SET;
final ErrorRuleSet ruleSet = Tagging.SOFT_FAIL_TAG_ERROR_RULE_SET;
final ErrorStatus status = ruleSet.handle(exception);
assertThat(status).isInstanceOf(ErrorStatus.class);
assertThat(status).isInstanceOf(IgnoreErrorStatus.class);
}

@Test
Expand All @@ -155,15 +157,15 @@ void test_SoftFailErrorRuleSet_AwsServiceException_OtherCode() {
.errorCode("InternalFailure")
.build())
.build();
final ErrorRuleSet ruleSet = Tagging.STACK_TAGS_ERROR_RULE_SET;
final ErrorRuleSet ruleSet = Tagging.SOFT_FAIL_TAG_ERROR_RULE_SET;
final ErrorStatus status = ruleSet.handle(exception);
assertThat(status).isInstanceOf(UnexpectedErrorStatus.class);
}

@Test
void test_SoftFailErrorRuleSet_OtherException() {
final Exception exception = new RuntimeException("test exception");
final ErrorRuleSet ruleSet = Tagging.STACK_TAGS_ERROR_RULE_SET;
final ErrorRuleSet ruleSet = Tagging.SOFT_FAIL_TAG_ERROR_RULE_SET;
final ErrorStatus status = ruleSet.handle(exception);
assertThat(status).isInstanceOf(UnexpectedErrorStatus.class);
}
Expand Down Expand Up @@ -270,7 +272,7 @@ void test_exclude() {

@Test
void test_bestEffortErrorRuleSet_emptyResourceTags() {
final ErrorRuleSet errorRuleSet = Tagging.getUpdateTagsAccessDeniedRuleSet(
final ErrorRuleSet errorRuleSet = Tagging.bestEffortErrorRuleSet(
Tagging.TagSet.builder()
.stackTags(Collections.singleton(Tag.builder().build()))
.stackTags(Collections.singleton(Tag.builder().build()))
Expand All @@ -281,12 +283,12 @@ void test_bestEffortErrorRuleSet_emptyResourceTags() {
.build()
);

assertThat(errorRuleSet).isEqualTo(Tagging.STACK_TAGS_ERROR_RULE_SET);
assertThat(errorRuleSet).isEqualTo(Tagging.SOFT_FAIL_TAG_ERROR_RULE_SET);
}

@Test
void test_bestEffortErrorRuleSet_nonEmptyResourceTags() {
assertThat(Tagging.getUpdateTagsAccessDeniedRuleSet(
assertThat(Tagging.bestEffortErrorRuleSet(
Tagging.TagSet.builder()
.stackTags(Collections.singleton(Tag.builder().build()))
.stackTags(Collections.singleton(Tag.builder().build()))
Expand All @@ -297,9 +299,9 @@ void test_bestEffortErrorRuleSet_nonEmptyResourceTags() {
.stackTags(Collections.singleton(Tag.builder().build()))
.resourceTags(Collections.singleton(Tag.builder().build()))
.build()
)).isEqualTo(Tagging.RESOURCE_TAG_ERROR_RULE_SET);
)).isEqualTo(Tagging.HARD_FAIL_TAG_ERROR_RULE_SET);

assertThat(Tagging.getUpdateTagsAccessDeniedRuleSet(
assertThat(Tagging.bestEffortErrorRuleSet(
Tagging.TagSet.builder()
.stackTags(Collections.singleton(Tag.builder().build()))
.stackTags(Collections.singleton(Tag.builder().build()))
Expand All @@ -310,9 +312,9 @@ void test_bestEffortErrorRuleSet_nonEmptyResourceTags() {
.stackTags(Collections.singleton(Tag.builder().build()))
.resourceTags(Collections.singleton(Tag.builder().build()))
.build()
)).isEqualTo(Tagging.RESOURCE_TAG_ERROR_RULE_SET);
)).isEqualTo(Tagging.HARD_FAIL_TAG_ERROR_RULE_SET);

assertThat(Tagging.getUpdateTagsAccessDeniedRuleSet(
assertThat(Tagging.bestEffortErrorRuleSet(
Tagging.TagSet.builder()
.stackTags(Collections.singleton(Tag.builder().build()))
.stackTags(Collections.singleton(Tag.builder().build()))
Expand All @@ -323,7 +325,7 @@ void test_bestEffortErrorRuleSet_nonEmptyResourceTags() {
.stackTags(Collections.singleton(Tag.builder().build()))
.resourceTags(Collections.emptySet())
.build()
)).isEqualTo(Tagging.RESOURCE_TAG_ERROR_RULE_SET);
)).isEqualTo(Tagging.HARD_FAIL_TAG_ERROR_RULE_SET);
}

@Test
Expand All @@ -341,7 +343,7 @@ public void safeCreate_allTagsSuccess() {
Mockito.when(handlerMethod.invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), Mockito.any(Tagging.TagSet.class)))
.thenReturn(ProgressEvent.success(null, progress.getCallbackContext()));

final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.createWithTaggingFallback(null, null, handlerMethod, progress, allTags);
final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.safeCreate(null, null, handlerMethod, progress, allTags);

Assertions.assertThat(result.isSuccess()).isTrue();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isAddTagsComplete()).isTrue();
Expand All @@ -357,18 +359,29 @@ public void safeCreate_allTagsFailAccessDenied() {
final Tagging.TagSet allTags = Tagging.TagSet.builder()
.systemTags(SYSTEM_TAGS)
.stackTags(STACK_TAGS)
.resourceTags(RESOURCE_TAGS)
.build();

final HandlerMethod<Void, CommonsTest.TaggingCallbackContext> handlerMethod = Mockito.mock(HandlerMethod.class);

Mockito.when(handlerMethod.invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), Mockito.any(Tagging.TagSet.class)))
.thenReturn(ProgressEvent.failed(null, progress.getCallbackContext(), HandlerErrorCode.AccessDenied, FAILED_MSG));
.thenReturn(ProgressEvent.failed(null, progress.getCallbackContext(), HandlerErrorCode.AccessDenied, FAILED_MSG))
.thenReturn(ProgressEvent.success(null, progress.getCallbackContext()));

final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.createWithTaggingFallback(null, null, handlerMethod, progress, allTags);
final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.safeCreate(null, null, handlerMethod, progress, allTags);

Assertions.assertThat(result.isFailed()).isTrue();
Assertions.assertThat(result.getErrorCode()).isEqualTo(HandlerErrorCode.UnauthorizedTaggingOperation);
Assertions.assertThat(result.isSuccess()).isTrue();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isSoftFailTags()).isTrue();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isAddTagsComplete()).isFalse();

ArgumentCaptor<Tagging.TagSet> captor = ArgumentCaptor.forClass(Tagging.TagSet.class);
Mockito.verify(handlerMethod, Mockito.times(2)).invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), captor.capture());

final Tagging.TagSet tagSetInvoke1 = captor.getAllValues().get(0);
final Tagging.TagSet tagSetInvoke2 = captor.getAllValues().get(1);

Assertions.assertThat(tagSetInvoke1).isEqualTo(allTags);
Assertions.assertThat(tagSetInvoke2).isEqualTo(Tagging.TagSet.builder().systemTags(SYSTEM_TAGS).build());
}

@Test
Expand All @@ -386,10 +399,11 @@ public void safeCreate_allTagsFailGeneric() {
Mockito.when(handlerMethod.invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), Mockito.any(Tagging.TagSet.class)))
.thenReturn(ProgressEvent.failed(null, progress.getCallbackContext(), HandlerErrorCode.InternalFailure, FAILED_MSG));

final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.createWithTaggingFallback(null, null, handlerMethod, progress, allTags);
final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.safeCreate(null, null, handlerMethod, progress, allTags);

Assertions.assertThat(result.isFailed()).isTrue();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isAddTagsComplete()).isFalse();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isSoftFailTags()).isFalse();

Mockito.verify(handlerMethod, Mockito.times(1))
.invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), Mockito.any(Tagging.TagSet.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,11 @@ private ProgressEvent<ResourceModel, CallbackContext> getTaggingErrorRuleSet(fin
progress,
exception,
DEFAULT_CUSTOM_DB_ENGINE_VERSION_ERROR_RULE_SET.extendWith(
Tagging.getUpdateTagsAccessDeniedRuleSet(
Tagging.bestEffortErrorRuleSet(
tagsToAdd,
tagsToRemove
tagsToRemove,
Tagging.SOFT_FAIL_IN_PROGRESS_TAGGING_ERROR_RULE_SET,
Tagging.HARD_FAIL_TAG_ERROR_RULE_SET
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private ProgressEvent<ResourceModel, CallbackContext> safeCreateCustomEngineVers
final ProxyClient<RdsClient> proxyClient,
final ProgressEvent<ResourceModel, CallbackContext> progress,
final Tagging.TagSet allTags) {
return Tagging.createWithTaggingFallback(proxy, proxyClient, this::createCustomEngineVersion, progress, allTags)
return Tagging.safeCreate(proxy, proxyClient, this::createCustomEngineVersion, progress, allTags)
.then(p -> Commons.execOnce(p, () -> {
final Tagging.TagSet extraTags = Tagging.TagSet.builder()
.stackTags(allTags.getStackTags())
Expand Down
Loading

0 comments on commit dc074b9

Please sign in to comment.