From b45f1b79796b3b95f27659aef7f8ab678d3ef96b Mon Sep 17 00:00:00 2001 From: AbdulRehman Date: Mon, 26 Feb 2024 11:08:47 -0500 Subject: [PATCH] JDBC QPT Implementation; with MySQL Connector (#1774) --- .../jdbc/manager/JdbcArrowTypeConverter.java | 6 +- .../jdbc/manager/JdbcMetadataHandler.java | 84 +++++++++++++++++++ .../jdbc/manager/JdbcRecordHandler.java | 12 +++ .../jdbc/qpt/JdbcQueryPassthrough.java | 72 ++++++++++++++++ .../mysql/MySqlMetadataHandler.java | 8 +- .../connectors/mysql/MySqlRecordHandler.java | 8 +- 6 files changed, 187 insertions(+), 3 deletions(-) create mode 100644 athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/qpt/JdbcQueryPassthrough.java diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcArrowTypeConverter.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcArrowTypeConverter.java index bf1105dcc6..d5778b462f 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcArrowTypeConverter.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcArrowTypeConverter.java @@ -48,12 +48,16 @@ public static ArrowType toArrowType(final int jdbcType, final int precision, fin int defaultScale = Integer.parseInt(configOptions.getOrDefault("default_scale", "0")); int resolvedPrecision = precision; int resolvedScale = scale; - boolean needsResolving = jdbcType == Types.NUMERIC && (precision == 0 && scale == 0); + boolean needsResolving = jdbcType == Types.NUMERIC && (precision == 0 && scale <= 0); + boolean decimalExceedingPrecision = jdbcType == Types.DECIMAL && precision > DEFAULT_PRECISION; // Resolve Precision and Scale if they're not available if (needsResolving) { resolvedPrecision = DEFAULT_PRECISION; resolvedScale = defaultScale; } + else if (decimalExceedingPrecision) { + resolvedPrecision = DEFAULT_PRECISION; + } ArrowType arrowType = JdbcToArrowUtils.getArrowTypeFromJdbcType( new JdbcFieldInfo(jdbcType, resolvedPrecision, resolvedScale), diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java index 77939d826e..a542182190 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcMetadataHandler.java @@ -25,7 +25,9 @@ import com.amazonaws.athena.connector.lambda.data.FieldBuilder; import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; import com.amazonaws.athena.connector.lambda.data.SupportedTypes; +import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation; import com.amazonaws.athena.connector.lambda.handlers.MetadataHandler; import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest; import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse; @@ -40,6 +42,7 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.connection.RdsSecretsCredentialProvider; +import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough; import com.amazonaws.athena.connectors.jdbc.splits.Splitter; import com.amazonaws.athena.connectors.jdbc.splits.SplitterFactory; import com.amazonaws.services.athena.AmazonAthena; @@ -58,11 +61,15 @@ import java.sql.Connection; import java.sql.DatabaseMetaData; +import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -82,6 +89,7 @@ public abstract class JdbcMetadataHandler private final JdbcConnectionFactory jdbcConnectionFactory; private final DatabaseConnectionConfig databaseConnectionConfig; private final SplitterFactory splitterFactory = new SplitterFactory(); + protected JdbcQueryPassthrough jdbcQueryPassthrough = new JdbcQueryPassthrough(); /** * Used only by Multiplexing handler. All calls will be delegated to respective database handler. @@ -235,11 +243,70 @@ public GetTableResponse doGetTable(final BlockAllocator blockAllocator, final Ge TableName caseInsensitiveTableMatch = caseInsensitiveTableSearch(connection, getTableRequest.getTableName().getSchemaName(), getTableRequest.getTableName().getTableName()); Schema caseInsensitiveSchemaMatch = getSchema(connection, caseInsensitiveTableMatch, partitionSchema); + return new GetTableResponse(getTableRequest.getCatalogName(), caseInsensitiveTableMatch, caseInsensitiveSchemaMatch, partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet())); } } + @Override + public GetTableResponse doGetQueryPassthroughSchema(final BlockAllocator blockAllocator, final GetTableRequest getTableRequest) + throws Exception + { + if (!getTableRequest.isQueryPassthrough()) { + throw new IllegalArgumentException("No Query passed through [{}]" + getTableRequest); + } + + jdbcQueryPassthrough.verify(getTableRequest.getQueryPassthroughArguments()); + String customerPassedQuery = getTableRequest.getQueryPassthroughArguments().get(JdbcQueryPassthrough.QUERY); + + try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) { + PreparedStatement preparedStatement = connection.prepareStatement(customerPassedQuery); + ResultSetMetaData metadata = preparedStatement.getMetaData(); + if (metadata == null) { + throw new UnsupportedOperationException("Query not supported: ResultSetMetaData not available for query: " + customerPassedQuery); + } + SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); + + for (int columnIndex = 1; columnIndex <= metadata.getColumnCount(); columnIndex++) { + String columnName = metadata.getColumnName(columnIndex); + String columnLabel = metadata.getColumnLabel(columnIndex); + //todo; is there a mechanism to pass both back to the engine? + columnName = columnName.equals(columnLabel) ? columnName : columnLabel; + + int precision = metadata.getPrecision(columnIndex); + int scale = metadata.getScale(columnIndex); + + ArrowType columnType = JdbcArrowTypeConverter.toArrowType( + metadata.getColumnType(columnIndex), + precision, + scale, + configOptions); + + if (columnType != null && SupportedTypes.isSupported(columnType)) { + if (columnType instanceof ArrowType.List) { + schemaBuilder.addListField(columnName, getArrayArrowTypeFromTypeName( + metadata.getTableName(columnIndex), + metadata.getColumnDisplaySize(columnIndex), + precision)); + } + else { + schemaBuilder.addField(FieldBuilder.newBuilder(columnName, columnType).build()); + } + } + else { + // Default to VARCHAR ArrowType + LOGGER.warn("getSchema: Unable to map type for column[" + columnName + + "] to a supported type, attempted " + columnType + " - defaulting type to VARCHAR."); + schemaBuilder.addField(FieldBuilder.newBuilder(columnName, new ArrowType.Utf8()).build()); + } + } + + Schema schema = schemaBuilder.build(); + return new GetTableResponse(getTableRequest.getCatalogName(), getTableRequest.getTableName(), schema, Collections.emptySet()); + } + } + /** * While being a no-op by default, this function will be overriden by subclasses that support this search. * @@ -373,4 +440,21 @@ protected ArrowType getArrayArrowTypeFromTypeName(String typeName, int precision // Default ARRAY type is VARCHAR. return new ArrowType.Utf8(); } + + /** + * Helper function that provides a single partition for Query Pass-Through + * + */ + protected GetSplitsResponse setupQueryPassthroughSplit(GetSplitsRequest request) + { + //Every split must have a unique location if we wish to spill to avoid failures + SpillLocation spillLocation = makeSpillLocation(request); + + //Since this is QPT query we return a fixed split. + Map qptArguments = request.getConstraints().getQueryPassthroughArguments(); + return new GetSplitsResponse(request.getCatalogName(), + Split.newBuilder(spillLocation, makeEncryptionKey()) + .applyProperties(qptArguments) + .build()); + } } diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java index dac0c784af..3c431efb5c 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java @@ -53,6 +53,7 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.connection.RdsSecretsCredentialProvider; +import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough; import com.amazonaws.services.athena.AmazonAthena; import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.secretsmanager.AWSSecretsManager; @@ -97,6 +98,8 @@ public abstract class JdbcRecordHandler private final JdbcConnectionFactory jdbcConnectionFactory; private final DatabaseConnectionConfig databaseConnectionConfig; + protected final JdbcQueryPassthrough queryPassthrough = new JdbcQueryPassthrough(); + /** * Used only by Multiplexing handler. All invocations will be delegated to respective database handler. */ @@ -316,4 +319,13 @@ protected Extractor makeExtractor(Field field, ResultSet resultSet, Map ARGUMENTS = Arrays.asList(QUERY); + + private static final Logger LOGGER = LoggerFactory.getLogger(JdbcQueryPassthrough.class); + + @Override + public String getFunctionSchema() + { + return SCHEMA_NAME; + } + + @Override + public String getFunctionName() + { + return NAME; + } + + @Override + public List getFunctionArguments() + { + return ARGUMENTS; + } + + @Override + public Logger getLogger() + { + return LOGGER; + } +} diff --git a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandler.java b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandler.java index 9fb4209e7d..d569f8dfe6 100644 --- a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandler.java +++ b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMetadataHandler.java @@ -66,7 +66,6 @@ import java.util.Map; import java.util.Set; -import static com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler.TABLES_AND_VIEWS; import static com.amazonaws.athena.connectors.mysql.MySqlConstants.MYSQL_DEFAULT_PORT; import static com.amazonaws.athena.connectors.mysql.MySqlConstants.MYSQL_DRIVER_CLASS; import static com.amazonaws.athena.connectors.mysql.MySqlConstants.MYSQL_NAME; @@ -137,6 +136,8 @@ public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAlloca .toArray(String[]::new)) )); + jdbcQueryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions); + return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build()); } @@ -193,6 +194,11 @@ public GetSplitsResponse doGetSplits( final BlockAllocator blockAllocator, final GetSplitsRequest getSplitsRequest) { LOGGER.info("{}: Catalog {}, table {}", getSplitsRequest.getQueryId(), getSplitsRequest.getTableName().getSchemaName(), getSplitsRequest.getTableName().getTableName()); + if (getSplitsRequest.getConstraints().isQueryPassThrough()) { + LOGGER.info("QPT Split Requested"); + return setupQueryPassthroughSplit(getSplitsRequest); + } + int partitionContd = decodeContinuationToken(getSplitsRequest); Set splits = new HashSet<>(); Block partitions = getSplitsRequest.getPartitions(); diff --git a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java index d662cacbcf..57a60f24e8 100644 --- a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java +++ b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java @@ -95,8 +95,14 @@ public MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, Jdb public PreparedStatement buildSplitSql(Connection jdbcConnection, String catalogName, TableName tableName, Schema schema, Constraints constraints, Split split) throws SQLException { - PreparedStatement preparedStatement = jdbcSplitQueryBuilder.buildSql(jdbcConnection, null, tableName.getSchemaName(), tableName.getTableName(), schema, constraints, split); + PreparedStatement preparedStatement; + if (constraints.isQueryPassThrough()) { + preparedStatement = buildQueryPassthroughSql(jdbcConnection, constraints); + } + else { + preparedStatement = jdbcSplitQueryBuilder.buildSql(jdbcConnection, null, tableName.getSchemaName(), tableName.getTableName(), schema, constraints, split); + } // Disable fetching all rows. preparedStatement.setFetchSize(Integer.MIN_VALUE);