Skip to content

Commit

Permalink
[Common][All] Remove silent failure mechanism and throw unauthorized …
Browse files Browse the repository at this point in the history
…tagging exception.
  • Loading branch information
moataz-mhmd committed Nov 9, 2023
1 parent f1a0c9f commit 72bd7ac
Show file tree
Hide file tree
Showing 37 changed files with 224 additions and 546 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,34 @@
import software.amazon.awssdk.services.rds.model.Tag;
import software.amazon.cloudformation.proxy.AmazonWebServicesClientProxy;
import software.amazon.cloudformation.proxy.HandlerErrorCode;
import software.amazon.cloudformation.proxy.OperationStatus;
import software.amazon.cloudformation.proxy.ProgressEvent;
import software.amazon.cloudformation.proxy.ProxyClient;
import software.amazon.rds.common.error.ErrorCode;
import software.amazon.rds.common.error.ErrorRuleSet;
import software.amazon.rds.common.error.ErrorStatus;

public final class Tagging {
public static final ErrorRuleSet SOFT_FAIL_IN_PROGRESS_TAGGING_ERROR_RULE_SET = ErrorRuleSet
public static final ErrorRuleSet IGNORE_LIST_TAGS_PERMISSION_DENIED_ERROR_RULE_SET = ErrorRuleSet
.extend(ErrorRuleSet.EMPTY_RULE_SET)
.withErrorCodes(ErrorStatus.ignore(OperationStatus.IN_PROGRESS),
.withErrorCodes(ErrorStatus.ignore(),
ErrorCode.AccessDenied,
ErrorCode.AccessDeniedException)
.build();

public static final ErrorRuleSet SOFT_FAIL_TAG_ERROR_RULE_SET = ErrorRuleSet
public static final ErrorRuleSet STACK_TAGS_ERROR_RULE_SET = ErrorRuleSet
.extend(ErrorRuleSet.EMPTY_RULE_SET)
.withErrorCodes(ErrorStatus.ignore(),
.withErrorCodes(ErrorStatus.failWith(HandlerErrorCode.UnauthorizedTaggingOperation),
ErrorCode.AccessDenied,
ErrorCode.AccessDeniedException
).build();

public static final ErrorRuleSet HARD_FAIL_TAG_ERROR_RULE_SET = ErrorRuleSet
public static final ErrorRuleSet RESOURCE_TAG_ERROR_RULE_SET = ErrorRuleSet
.extend(ErrorRuleSet.EMPTY_RULE_SET)
.withErrorCodes(ErrorStatus.failWith(HandlerErrorCode.AccessDenied),
ErrorCode.AccessDenied,
ErrorCode.AccessDeniedException
).build();
public static final String RDS_ADD_TAGS_TO_RESOURCE_ACTION = "rds:AddTagsToResource";

public static TagSet exclude(final TagSet from, final TagSet what) {
final Set<Tag> systemTags = new LinkedHashSet<>(from.getSystemTags());
Expand Down Expand Up @@ -194,56 +194,52 @@ private static RemoveTagsFromResourceRequest removeTagsFromResourceRequest(
.build();
}

public static ErrorRuleSet bestEffortErrorRuleSet(
public static ErrorRuleSet getUpdateTagsAccessDeniedRuleSet(
final TagSet tagsToAdd,
final TagSet tagsToRemove
) {
return bestEffortErrorRuleSet(tagsToAdd, tagsToRemove, SOFT_FAIL_TAG_ERROR_RULE_SET, HARD_FAIL_TAG_ERROR_RULE_SET);
return getUpdateTagsAccessDeniedRuleSet(tagsToAdd, tagsToRemove, STACK_TAGS_ERROR_RULE_SET, RESOURCE_TAG_ERROR_RULE_SET);
}

public static ErrorRuleSet bestEffortErrorRuleSet(
public static ErrorRuleSet getUpdateTagsAccessDeniedRuleSet(
final TagSet tagsToAdd,
final TagSet tagsToRemove,
final ErrorRuleSet softFailErrorRuleSet,
final ErrorRuleSet hardFailErrorRuleSet
final ErrorRuleSet stackTagsErrorRuleSet,
final ErrorRuleSet resourceTagsErrorRuleSet
) {
// Only soft fail if the customer provided no resource-level tags
/* If the tagging operation comes across an AccessDenied error, we will throw an UnauthorizedTaggingOperation
errorCode for stack level tags. For Resource tags, if they are included, we will throw an AccessDenied error.
This is done to ensure backward compatibility. */
if (tagsToAdd.getResourceTags().isEmpty() && tagsToRemove.getResourceTags().isEmpty()) {
return softFailErrorRuleSet;
return stackTagsErrorRuleSet;
}
return hardFailErrorRuleSet;
return resourceTagsErrorRuleSet;
}

public static <M, C extends TaggingContext.Provider> ProgressEvent<M, C> safeCreate(
public static <M, C extends TaggingContext.Provider> ProgressEvent<M, C> createWithTaggingFallback(
final AmazonWebServicesClientProxy proxy,
final ProxyClient<RdsClient> rdsProxyClient,
final HandlerMethod<M, C> handlerMethod,
final ProgressEvent<M, C> progress,
final Tagging.TagSet allTags
) {
return progress.then(p -> {
final C context = p.getCallbackContext();
if (context.getTaggingContext().isSoftFailTags()) {
return p;
final ProgressEvent<M, C> allTagsResult = handlerMethod.invoke(proxy, rdsProxyClient, progress, allTags);
if (allTagsResult.isFailed()) {
if (isUnauthorizedTaggingFailure(allTagsResult, allTags)) {
allTagsResult.setErrorCode(HandlerErrorCode.UnauthorizedTaggingOperation);
}
final ProgressEvent<M, C> allTagsResult = handlerMethod.invoke(proxy, rdsProxyClient, p, allTags);
if (allTagsResult.isFailed()) {
if (HandlerErrorCode.AccessDenied.equals(allTagsResult.getErrorCode())) {
context.getTaggingContext().setSoftFailTags(true);
return ProgressEvent.progress(allTagsResult.getResourceModel(), context);
}
return allTagsResult;
}
allTagsResult.getCallbackContext().getTaggingContext().setAddTagsComplete(true);
return allTagsResult;
}).then(p -> {
final C context = p.getCallbackContext();
if (!context.getTaggingContext().isSoftFailTags()) {
return p;
}
final Tagging.TagSet systemTags = Tagging.TagSet.builder().systemTags(allTags.getSystemTags()).build();
return handlerMethod.invoke(proxy, rdsProxyClient, p, systemTags);
});
}
allTagsResult.getCallbackContext().getTaggingContext().setAddTagsComplete(true);
return allTagsResult;
}

private static <M, C extends TaggingContext.Provider> boolean isUnauthorizedTaggingFailure(final ProgressEvent<M, C> allTagsResult,
final TagSet allTags) {
return HandlerErrorCode.AccessDenied.equals(allTagsResult.getErrorCode()) &&
allTags.getResourceTags().isEmpty() &&
allTagsResult.getMessage() != null &&
allTagsResult.getMessage().contains(RDS_ADD_TAGS_TO_RESOURCE_ACTION);
}

private static void addToMapIfAbsent(Map<String, Tag> allTags, Collection<Tag> tags) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
@lombok.ToString
@lombok.EqualsAndHashCode
public class TaggingContext {
private boolean softFailTags;
private boolean addTagsComplete;

public interface Provider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;

Expand All @@ -40,15 +39,14 @@
import software.amazon.cloudformation.proxy.ResourceHandlerRequest;
import software.amazon.rds.common.error.ErrorRuleSet;
import software.amazon.rds.common.error.ErrorStatus;
import software.amazon.rds.common.error.IgnoreErrorStatus;
import software.amazon.rds.common.error.UnexpectedErrorStatus;
import software.amazon.rds.common.logging.LoggingProxyClient;
import software.amazon.rds.common.logging.RequestLogger;
import software.amazon.rds.common.printer.FilteredJsonPrinter;

public class TaggingTest extends ProxyClientTestBase {

private final static String FAILED_MSG = "Test message: failed";
private final static String FAILED_MSG = "User is not authorized to do rds:AddTagsToResource action";

final Set<Tag> SYSTEM_TAGS = Stream.of(
Tag.builder().key("system-tag-key-1").value("system-tag-value-1").build(),
Expand Down Expand Up @@ -145,9 +143,9 @@ void test_SoftFailErrorRuleSet_AwsServiceException_AccessDenied() {
.errorCode("AccessDenied")
.build())
.build();
final ErrorRuleSet ruleSet = Tagging.SOFT_FAIL_TAG_ERROR_RULE_SET;
final ErrorRuleSet ruleSet = Tagging.STACK_TAGS_ERROR_RULE_SET;
final ErrorStatus status = ruleSet.handle(exception);
assertThat(status).isInstanceOf(IgnoreErrorStatus.class);
assertThat(status).isInstanceOf(ErrorStatus.class);
}

@Test
Expand All @@ -157,15 +155,15 @@ void test_SoftFailErrorRuleSet_AwsServiceException_OtherCode() {
.errorCode("InternalFailure")
.build())
.build();
final ErrorRuleSet ruleSet = Tagging.SOFT_FAIL_TAG_ERROR_RULE_SET;
final ErrorRuleSet ruleSet = Tagging.STACK_TAGS_ERROR_RULE_SET;
final ErrorStatus status = ruleSet.handle(exception);
assertThat(status).isInstanceOf(UnexpectedErrorStatus.class);
}

@Test
void test_SoftFailErrorRuleSet_OtherException() {
final Exception exception = new RuntimeException("test exception");
final ErrorRuleSet ruleSet = Tagging.SOFT_FAIL_TAG_ERROR_RULE_SET;
final ErrorRuleSet ruleSet = Tagging.STACK_TAGS_ERROR_RULE_SET;
final ErrorStatus status = ruleSet.handle(exception);
assertThat(status).isInstanceOf(UnexpectedErrorStatus.class);
}
Expand Down Expand Up @@ -272,7 +270,7 @@ void test_exclude() {

@Test
void test_bestEffortErrorRuleSet_emptyResourceTags() {
final ErrorRuleSet errorRuleSet = Tagging.bestEffortErrorRuleSet(
final ErrorRuleSet errorRuleSet = Tagging.getUpdateTagsAccessDeniedRuleSet(
Tagging.TagSet.builder()
.stackTags(Collections.singleton(Tag.builder().build()))
.stackTags(Collections.singleton(Tag.builder().build()))
Expand All @@ -283,12 +281,12 @@ void test_bestEffortErrorRuleSet_emptyResourceTags() {
.build()
);

assertThat(errorRuleSet).isEqualTo(Tagging.SOFT_FAIL_TAG_ERROR_RULE_SET);
assertThat(errorRuleSet).isEqualTo(Tagging.STACK_TAGS_ERROR_RULE_SET);
}

@Test
void test_bestEffortErrorRuleSet_nonEmptyResourceTags() {
assertThat(Tagging.bestEffortErrorRuleSet(
assertThat(Tagging.getUpdateTagsAccessDeniedRuleSet(
Tagging.TagSet.builder()
.stackTags(Collections.singleton(Tag.builder().build()))
.stackTags(Collections.singleton(Tag.builder().build()))
Expand All @@ -299,9 +297,9 @@ void test_bestEffortErrorRuleSet_nonEmptyResourceTags() {
.stackTags(Collections.singleton(Tag.builder().build()))
.resourceTags(Collections.singleton(Tag.builder().build()))
.build()
)).isEqualTo(Tagging.HARD_FAIL_TAG_ERROR_RULE_SET);
)).isEqualTo(Tagging.RESOURCE_TAG_ERROR_RULE_SET);

assertThat(Tagging.bestEffortErrorRuleSet(
assertThat(Tagging.getUpdateTagsAccessDeniedRuleSet(
Tagging.TagSet.builder()
.stackTags(Collections.singleton(Tag.builder().build()))
.stackTags(Collections.singleton(Tag.builder().build()))
Expand All @@ -312,9 +310,9 @@ void test_bestEffortErrorRuleSet_nonEmptyResourceTags() {
.stackTags(Collections.singleton(Tag.builder().build()))
.resourceTags(Collections.singleton(Tag.builder().build()))
.build()
)).isEqualTo(Tagging.HARD_FAIL_TAG_ERROR_RULE_SET);
)).isEqualTo(Tagging.RESOURCE_TAG_ERROR_RULE_SET);

assertThat(Tagging.bestEffortErrorRuleSet(
assertThat(Tagging.getUpdateTagsAccessDeniedRuleSet(
Tagging.TagSet.builder()
.stackTags(Collections.singleton(Tag.builder().build()))
.stackTags(Collections.singleton(Tag.builder().build()))
Expand All @@ -325,7 +323,7 @@ void test_bestEffortErrorRuleSet_nonEmptyResourceTags() {
.stackTags(Collections.singleton(Tag.builder().build()))
.resourceTags(Collections.emptySet())
.build()
)).isEqualTo(Tagging.HARD_FAIL_TAG_ERROR_RULE_SET);
)).isEqualTo(Tagging.RESOURCE_TAG_ERROR_RULE_SET);
}

@Test
Expand All @@ -343,7 +341,7 @@ public void safeCreate_allTagsSuccess() {
Mockito.when(handlerMethod.invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), Mockito.any(Tagging.TagSet.class)))
.thenReturn(ProgressEvent.success(null, progress.getCallbackContext()));

final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.safeCreate(null, null, handlerMethod, progress, allTags);
final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.createWithTaggingFallback(null, null, handlerMethod, progress, allTags);

Assertions.assertThat(result.isSuccess()).isTrue();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isAddTagsComplete()).isTrue();
Expand All @@ -359,29 +357,18 @@ public void safeCreate_allTagsFailAccessDenied() {
final Tagging.TagSet allTags = Tagging.TagSet.builder()
.systemTags(SYSTEM_TAGS)
.stackTags(STACK_TAGS)
.resourceTags(RESOURCE_TAGS)
.build();

final HandlerMethod<Void, CommonsTest.TaggingCallbackContext> handlerMethod = Mockito.mock(HandlerMethod.class);

Mockito.when(handlerMethod.invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), Mockito.any(Tagging.TagSet.class)))
.thenReturn(ProgressEvent.failed(null, progress.getCallbackContext(), HandlerErrorCode.AccessDenied, FAILED_MSG))
.thenReturn(ProgressEvent.success(null, progress.getCallbackContext()));
.thenReturn(ProgressEvent.failed(null, progress.getCallbackContext(), HandlerErrorCode.AccessDenied, FAILED_MSG));

final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.safeCreate(null, null, handlerMethod, progress, allTags);
final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.createWithTaggingFallback(null, null, handlerMethod, progress, allTags);

Assertions.assertThat(result.isSuccess()).isTrue();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isSoftFailTags()).isTrue();
Assertions.assertThat(result.isFailed()).isTrue();
Assertions.assertThat(result.getErrorCode()).isEqualTo(HandlerErrorCode.UnauthorizedTaggingOperation);
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isAddTagsComplete()).isFalse();

ArgumentCaptor<Tagging.TagSet> captor = ArgumentCaptor.forClass(Tagging.TagSet.class);
Mockito.verify(handlerMethod, Mockito.times(2)).invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), captor.capture());

final Tagging.TagSet tagSetInvoke1 = captor.getAllValues().get(0);
final Tagging.TagSet tagSetInvoke2 = captor.getAllValues().get(1);

Assertions.assertThat(tagSetInvoke1).isEqualTo(allTags);
Assertions.assertThat(tagSetInvoke2).isEqualTo(Tagging.TagSet.builder().systemTags(SYSTEM_TAGS).build());
}

@Test
Expand All @@ -399,11 +386,10 @@ public void safeCreate_allTagsFailGeneric() {
Mockito.when(handlerMethod.invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), Mockito.any(Tagging.TagSet.class)))
.thenReturn(ProgressEvent.failed(null, progress.getCallbackContext(), HandlerErrorCode.InternalFailure, FAILED_MSG));

final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.safeCreate(null, null, handlerMethod, progress, allTags);
final ProgressEvent<Void, CommonsTest.TaggingCallbackContext> result = Tagging.createWithTaggingFallback(null, null, handlerMethod, progress, allTags);

Assertions.assertThat(result.isFailed()).isTrue();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isAddTagsComplete()).isFalse();
Assertions.assertThat(result.getCallbackContext().getTaggingContext().isSoftFailTags()).isFalse();

Mockito.verify(handlerMethod, Mockito.times(1))
.invoke(Mockito.any(), Mockito.any(), Mockito.any(ProgressEvent.class), Mockito.any(Tagging.TagSet.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,9 @@ private ProgressEvent<ResourceModel, CallbackContext> getTaggingErrorRuleSet(fin
progress,
exception,
DEFAULT_CUSTOM_DB_ENGINE_VERSION_ERROR_RULE_SET.extendWith(
Tagging.bestEffortErrorRuleSet(
Tagging.getUpdateTagsAccessDeniedRuleSet(
tagsToAdd,
tagsToRemove,
Tagging.SOFT_FAIL_IN_PROGRESS_TAGGING_ERROR_RULE_SET,
Tagging.HARD_FAIL_TAG_ERROR_RULE_SET
tagsToRemove
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private ProgressEvent<ResourceModel, CallbackContext> safeCreateCustomEngineVers
final ProxyClient<RdsClient> proxyClient,
final ProgressEvent<ResourceModel, CallbackContext> progress,
final Tagging.TagSet allTags) {
return Tagging.safeCreate(proxy, proxyClient, this::createCustomEngineVersion, progress, allTags)
return Tagging.createWithTaggingFallback(proxy, proxyClient, this::createCustomEngineVersion, progress, allTags)
.then(p -> Commons.execOnce(p, () -> {
final Tagging.TagSet extraTags = Tagging.TagSet.builder()
.stackTags(allTags.getStackTags())
Expand Down
Loading

0 comments on commit 72bd7ac

Please sign in to comment.