Skip to content

Commit

Permalink
[DBInstance] Use DBParameterGroup in CreateReadReplica for MySql engi…
Browse files Browse the repository at this point in the history
…ne (#469)

* [DBInstance] Use DBParameterGroup in CreateReadReplica for MySql engine
  • Loading branch information
moataz-mhmd authored Oct 27, 2023
1 parent f524650 commit 82e98c7
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ protected ProgressEvent<ResourceModel, CallbackContext> stopAutomaticBackupRepli
final ProxyClient<RdsClient> sourceRegionClient,
final String region
) {
final ProxyClient<RdsClient> rdsClient = proxy.newProxy(() -> new RdsClientProvider().getClientForRegion(region));
final ProxyClient<RdsClient> rdsClient = new LoggingProxyClient<>(logger, proxy.newProxy(() -> new RdsClientProvider().getClientForRegion(region)));

return proxy.initiate("rds::stop-db-instance-automatic-backup-replication", rdsClient, progress.getResourceModel(), progress.getCallbackContext())
.translateToServiceRequest(resourceModel -> Translator.stopDbInstanceAutomatedBackupsReplicationRequest(dbInstanceArn))
Expand All @@ -1127,7 +1127,7 @@ protected ProgressEvent<ResourceModel, CallbackContext> startAutomaticBackupRepl
final ProxyClient<RdsClient> sourceRegionClient,
final String region
) {
final ProxyClient<RdsClient> rdsClient = proxy.newProxy(() -> new RdsClientProvider().getClientForRegion(region));
final ProxyClient<RdsClient> rdsClient = new LoggingProxyClient<>(logger, proxy.newProxy(() -> new RdsClientProvider().getClientForRegion(region)));

return proxy.initiate("rds::start-db-instance-automatic-backup-replication", rdsClient, progress.getResourceModel(), progress.getCallbackContext())
.translateToServiceRequest(resourceModel -> Translator.startDbInstanceAutomatedBackupsReplicationRequest(dbInstanceArn))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class CallbackContext extends StdCallbackContext implements TaggingContex
private boolean automaticBackupReplicationStopped;
private boolean automaticBackupReplicationStarted;
private String dbInstanceArn;
private String currentRegion;

private TaggingContext taggingContext;
private Map<String, Long> timestamps;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
import software.amazon.rds.common.handler.HandlerConfig;
import software.amazon.rds.common.handler.HandlerMethod;
import software.amazon.rds.common.handler.Tagging;
import software.amazon.rds.common.logging.LoggingProxyClient;
import software.amazon.rds.common.request.RequestValidationException;
import software.amazon.rds.common.request.ValidatedRequest;
import software.amazon.rds.common.request.Validations;
import software.amazon.rds.common.util.IdentifierFactory;
import software.amazon.rds.dbinstance.client.ApiVersion;
import software.amazon.rds.dbinstance.client.RdsClientProvider;
import software.amazon.rds.dbinstance.client.VersionedProxyClient;
import software.amazon.rds.dbinstance.util.ResourceModelHelper;

Expand Down Expand Up @@ -70,6 +72,7 @@ protected ProgressEvent<ResourceModel, CallbackContext> handleRequest(
final ResourceModel model = request.getDesiredResourceState();
final Collection<DBInstanceRole> desiredRoles = model.getAssociatedRoles();
final boolean isMultiAZ = BooleanUtils.isTrue(model.getMultiAZ());
callbackContext.setCurrentRegion(request.getRegion());

if (StringUtils.isNullOrEmpty(model.getDBInstanceIdentifier())) {
model.setDBInstanceIdentifier(instanceIdentifierFactory.newIdentifier()
Expand All @@ -89,7 +92,7 @@ protected ProgressEvent<ResourceModel, CallbackContext> handleRequest(
.then(progress -> {
if (StringUtils.isNullOrEmpty(progress.getResourceModel().getEngine())) {
try {
model.setEngine(fetchEngine(rdsProxyClient.defaultClient(), progress.getResourceModel()));
model.setEngine(fetchEngine(rdsProxyClient.defaultClient(), progress, proxy));
} catch (Exception e) {
return Commons.handleException(progress, e, DB_INSTANCE_FETCH_ENGINE_RULE_SET);
}
Expand Down Expand Up @@ -201,7 +204,12 @@ private HandlerMethod<ResourceModel, CallbackContext> safeAddTags(final HandlerM
return (proxy, rdsProxyClient, progress, tagSet) -> progress.then(p -> Tagging.safeCreate(proxy, rdsProxyClient, handlerMethod, progress, tagSet));
}

private String fetchEngine(final ProxyClient<RdsClient> client, final ResourceModel model) {
private String fetchEngine(final ProxyClient<RdsClient> client,
final ProgressEvent<ResourceModel, CallbackContext> progress,
final AmazonWebServicesClientProxy proxy) {
final ResourceModel model = progress.getResourceModel();
final String currentRegion = progress.getCallbackContext().getCurrentRegion();

if (ResourceModelHelper.isRestoreFromSnapshot(model)) {
return fetchDBSnapshot(client, model).engine();
}
Expand All @@ -210,10 +218,29 @@ private String fetchEngine(final ProxyClient<RdsClient> client, final ResourceMo
}

if (ResourceModelHelper.isDBInstanceReadReplica(model)) {
return fetchDBInstance(client, model.getSourceDBInstanceIdentifier()).engine();
final String sourceDBInstanceArn = model.getSourceDBInstanceIdentifier();
final String sourceDBInstanceIdOrArn = ResourceModelHelper.isValidArn(sourceDBInstanceArn) ?
ResourceModelHelper.getResourceNameFromArn(sourceDBInstanceArn) : sourceDBInstanceArn;
if (ResourceModelHelper.isCrossRegionDBInstanceReadReplica(model, currentRegion)) {
final String sourceRegion = ResourceModelHelper.getRegionFromArn(sourceDBInstanceArn);
final ProxyClient<RdsClient> sourceRegionClient = new LoggingProxyClient<>(logger,
proxy.newProxy(() -> new RdsClientProvider().getClientForRegion(sourceRegion)));
return fetchDBInstance(sourceRegionClient, sourceDBInstanceIdOrArn).engine();
} else {
return fetchDBInstance(client, sourceDBInstanceIdOrArn ).engine();
}
}
if (ResourceModelHelper.isDBClusterReadReplica(model)) {
return fetchDBCluster(client, model.getSourceDBClusterIdentifier()).engine();
final String sourceDBClusterArn = model.getSourceDBClusterIdentifier();
final String sourceDBClusterIdOrArn = ResourceModelHelper.isValidArn(sourceDBClusterArn) ?
ResourceModelHelper.getResourceNameFromArn(sourceDBClusterArn) : sourceDBClusterArn;
if (ResourceModelHelper.isCrossRegionDBClusterReadReplica(model, currentRegion)) {
final String sourceRegion = ResourceModelHelper.getRegionFromArn(sourceDBClusterArn);
final ProxyClient<RdsClient> sourceRegionClient = proxy.newProxy(() -> new RdsClientProvider().getClientForRegion(sourceRegion));
return fetchDBCluster(sourceRegionClient, sourceDBClusterIdOrArn).engine();
} else {
return fetchDBCluster(client, sourceDBClusterIdOrArn).engine();
}
}

if (ResourceModelHelper.isRestoreToPointInTime(model)) {
Expand Down Expand Up @@ -378,12 +405,13 @@ private ProgressEvent<ResourceModel, CallbackContext> createDbInstanceReadReplic
final ProgressEvent<ResourceModel, CallbackContext> progress,
final Tagging.TagSet tagSet
) {
final String currentRegion = progress.getCallbackContext().getCurrentRegion();
return proxy.initiate(
"rds::create-db-instance-read-replica",
rdsProxyClient,
progress.getResourceModel(),
progress.getCallbackContext()
).translateToServiceRequest(model -> Translator.createDbInstanceReadReplicaRequest(model, tagSet))
).translateToServiceRequest(model -> Translator.createDbInstanceReadReplicaRequest(model, tagSet, currentRegion))
.backoffDelay(config.getBackoff())
.makeServiceCall((createRequest, proxyInvocation) -> proxyInvocation.injectCredentialsAndInvokeV2(
createRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.amazonaws.arn.Arn;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.ObjectUtils;

import com.amazonaws.arn.Arn;
import com.google.common.annotations.VisibleForTesting;
import software.amazon.awssdk.services.ec2.model.DescribeSecurityGroupsRequest;
import software.amazon.awssdk.services.ec2.model.Filter;
Expand Down Expand Up @@ -108,7 +108,8 @@ public static DescribeDbClusterSnapshotsRequest describeDbClusterSnapshotsReques

public static CreateDbInstanceReadReplicaRequest createDbInstanceReadReplicaRequest(
final ResourceModel model,
final Tagging.TagSet tagSet
final Tagging.TagSet tagSet,
final String currentRegion
) {
final CreateDbInstanceReadReplicaRequest.Builder builder = CreateDbInstanceReadReplicaRequest.builder()
.autoMinorVersionUpgrade(model.getAutoMinorVersionUpgrade())
Expand Down Expand Up @@ -152,6 +153,11 @@ public static CreateDbInstanceReadReplicaRequest createDbInstanceReadReplicaRequ
builder.storageThroughput(model.getStorageThroughput());
builder.storageType(model.getStorageType());
}

if (ResourceModelHelper.isCrossRegionDBInstanceReadReplica(model, currentRegion) && ResourceModelHelper.isMySQL(model) ) {
builder.dbParameterGroupName(model.getDBParameterGroupName());
}

return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
package software.amazon.rds.dbinstance.util;

import java.util.Optional;
import java.util.Set;

import org.apache.commons.lang3.BooleanUtils;

import com.amazonaws.arn.Arn;
import com.amazonaws.util.StringUtils;
import com.google.common.collect.ImmutableSet;
import org.apache.commons.lang3.BooleanUtils;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.rds.dbinstance.ResourceModel;

import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public final class ResourceModelHelper {
private static final Set<String> SQLSERVER_ENGINES_WITH_MIRRORING = ImmutableSet.of(
"sqlserver-ee",
"sqlserver-se"
);
private static final String SQLSERVER_ENGINE = "sqlserver";
private static final String MYSQL_ENGINE_PREFIX = "mysql";
private static final String ORACLE_ENGINE_PREFIX = "oracle";

public static boolean shouldUpdateAfterCreate(final ResourceModel model) {
return (isReadReplica(model) ||
Expand Down Expand Up @@ -48,6 +49,10 @@ public static boolean isSqlServer(final ResourceModel model) {
return engine == null || engine.contains(SQLSERVER_ENGINE);
}

public static boolean isMySQL(final ResourceModel model) {
final String engine = model.getEngine();
return engine != null && engine.toLowerCase().startsWith(MYSQL_ENGINE_PREFIX);
}

public static boolean isStorageParametersModified(final ResourceModel model) {
return StringUtils.hasValue(model.getAllocatedStorage()) ||
Expand All @@ -72,10 +77,40 @@ public static boolean isDBInstanceReadReplica(final ResourceModel model) {
return StringUtils.hasValue(model.getSourceDBInstanceIdentifier());
}

public static boolean isCrossRegionDBInstanceReadReplica(final ResourceModel model, final String currentRegion) {
final String sourceDBInstanceIdentifier = model.getSourceDBInstanceIdentifier();
return isDBInstanceReadReplica(model) &&
isValidArn(sourceDBInstanceIdentifier) &&
!getRegionFromArn(sourceDBInstanceIdentifier).equals(currentRegion);
}
public static boolean isDBClusterReadReplica(final ResourceModel model) {
return StringUtils.hasValue(model.getSourceDBClusterIdentifier());
}

public static boolean isCrossRegionDBClusterReadReplica(final ResourceModel model, final String currentRegion) {
final String sourceDBClusterIdentifier = model.getSourceDBClusterIdentifier();
return isDBClusterReadReplica(model) &&
isValidArn(sourceDBClusterIdentifier) &&
!getRegionFromArn(sourceDBClusterIdentifier).equals(currentRegion);
}

public static boolean isValidArn(final String arn) {
try {
Arn.fromString(arn);
return true;
} catch (IllegalArgumentException e) {
return false;
}

}
public static String getRegionFromArn(final String arn) {
return Arn.fromString(arn).getRegion();
}

public static String getResourceNameFromArn(final String arn) {
return Arn.fromString(arn).getResource().getResource();
}

public static boolean isRestoreFromSnapshot(final ResourceModel model) {
return StringUtils.hasValue(model.getDBSnapshotIdentifier());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ public abstract class AbstractHandlerTest extends AbstractTestBase<DBInstance, R
protected static final String RESTORE_TIME_UTC_PLUS_5 = "2007-04-05T17:30:00+05:00";
protected static final String SOURCE_DB_INSTANCE_AUTOMATED_BACKUPS_ARN_EMPTY = null;
protected static final String SOURCE_DB_INSTANCE_AUTOMATED_BACKUPS_ARN_NON_EMPTY = "arn:aws:rds:us-east-1:123456789012:snapshot:rds:backup-name";
protected static final String SOURCE_DB_INSTANCE_ARN = "arn:aws:rds:us-east-1:123456789012:instance:rds:source-db-instance";
protected static final String SOURCE_DBI_RESOURCE_ID_EMPTY = null;
protected static final String SOURCE_DBI_RESOURCE_ID_NON_EMPTY = "dbi-instance-identifier";
protected static final boolean USE_LATEST_RESTORABLE_TIME_NO = false;
Expand All @@ -194,6 +195,7 @@ public abstract class AbstractHandlerTest extends AbstractTestBase<DBInstance, R
protected static final String MSG_GENERIC_ERR = "Error";
protected static final String AUTOMATIC_BACKUP_REPLICATION_REGION = "eu-west-1";
protected static final String AUTOMATIC_BACKUP_REPLICATION_REGION_ALTER = "eu-west-2";
protected static final String CURRENT_REGION = "eu-west-1";


protected static final ResourceModel RESOURCE_MODEL_NO_IDENTIFIER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,27 @@ public void test_createDBInstanceRequest_stableHashCode() {
}

@Test
public void test_createReadReplicaRequest_parameterGroupNotSet() {
final ResourceModel model = RESOURCE_MODEL_BLDR().build();
public void test_createReadReplicaRequest_parameterGroupSetForMySql() {
final ResourceModel model = RESOURCE_MODEL_BLDR()
.sourceDBInstanceIdentifier(SOURCE_DB_INSTANCE_ARN)
.build();

final CreateDbInstanceReadReplicaRequest request = Translator.createDbInstanceReadReplicaRequest(model, Tagging.TagSet.builder().build(), CURRENT_REGION);
Assertions.assertNotNull(request.dbParameterGroupName());
}

@Test
public void test_createReadReplicaRequest_parameterGroupNotSetForSqlServer() {
final ResourceModel model = RESOURCE_MODEL_BLDR()
.engine(ENGINE_SQLSERVER_SE)
.dBSnapshotIdentifier("snapshot")
.storageType("gp3")
.iops(100)
.storageThroughput(200)
.allocatedStorage("300")
.build();

final CreateDbInstanceReadReplicaRequest request = Translator.createDbInstanceReadReplicaRequest(model, Tagging.TagSet.builder().build());
final CreateDbInstanceReadReplicaRequest request = Translator.createDbInstanceReadReplicaRequest(model, Tagging.TagSet.builder().build(), CURRENT_REGION);
Assertions.assertNull(request.dbParameterGroupName());
}

Expand All @@ -293,7 +310,7 @@ public void test_createReadReplicaRequest_blankSourceRegionIsNotSet() {
final ResourceModel model = ResourceModel.builder()
.sourceRegion("")
.build();
final CreateDbInstanceReadReplicaRequest request = Translator.createDbInstanceReadReplicaRequest(model, Tagging.TagSet.builder().build());
final CreateDbInstanceReadReplicaRequest request = Translator.createDbInstanceReadReplicaRequest(model, Tagging.TagSet.builder().build(), CURRENT_REGION);
Assertions.assertNull(request.sourceRegion());
}

Expand All @@ -303,7 +320,7 @@ public void test_createReadReplicaRequest_nonBlankSourceRegionIsSet() {
final ResourceModel model = ResourceModel.builder()
.sourceRegion(sourceRegion)
.build();
final CreateDbInstanceReadReplicaRequest request = Translator.createDbInstanceReadReplicaRequest(model, Tagging.TagSet.builder().build());
final CreateDbInstanceReadReplicaRequest request = Translator.createDbInstanceReadReplicaRequest(model, Tagging.TagSet.builder().build(), CURRENT_REGION);
Assertions.assertEquals(sourceRegion, request.sourceRegion());
}

Expand Down

0 comments on commit 82e98c7

Please sign in to comment.