Skip to content

Commit

Permalink
JDBC QPT Implementation; with MySQL Connector (#1774)
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdulR3hman authored Feb 26, 2024
1 parent 1d0afec commit b45f1b7
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,16 @@ public static ArrowType toArrowType(final int jdbcType, final int precision, fin
int defaultScale = Integer.parseInt(configOptions.getOrDefault("default_scale", "0"));
int resolvedPrecision = precision;
int resolvedScale = scale;
boolean needsResolving = jdbcType == Types.NUMERIC && (precision == 0 && scale == 0);
boolean needsResolving = jdbcType == Types.NUMERIC && (precision == 0 && scale <= 0);
boolean decimalExceedingPrecision = jdbcType == Types.DECIMAL && precision > DEFAULT_PRECISION;
// Resolve Precision and Scale if they're not available
if (needsResolving) {
resolvedPrecision = DEFAULT_PRECISION;
resolvedScale = defaultScale;
}
else if (decimalExceedingPrecision) {
resolvedPrecision = DEFAULT_PRECISION;
}

ArrowType arrowType = JdbcToArrowUtils.getArrowTypeFromJdbcType(
new JdbcFieldInfo(jdbcType, resolvedPrecision, resolvedScale),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import com.amazonaws.athena.connector.lambda.data.FieldBuilder;
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.data.SupportedTypes;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.handlers.MetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
Expand All @@ -40,6 +42,7 @@
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider;
import com.amazonaws.athena.connectors.jdbc.connection.RdsSecretsCredentialProvider;
import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough;
import com.amazonaws.athena.connectors.jdbc.splits.Splitter;
import com.amazonaws.athena.connectors.jdbc.splits.SplitterFactory;
import com.amazonaws.services.athena.AmazonAthena;
Expand All @@ -58,11 +61,15 @@

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
Expand All @@ -82,6 +89,7 @@ public abstract class JdbcMetadataHandler
private final JdbcConnectionFactory jdbcConnectionFactory;
private final DatabaseConnectionConfig databaseConnectionConfig;
private final SplitterFactory splitterFactory = new SplitterFactory();
protected JdbcQueryPassthrough jdbcQueryPassthrough = new JdbcQueryPassthrough();

/**
* Used only by Multiplexing handler. All calls will be delegated to respective database handler.
Expand Down Expand Up @@ -235,11 +243,70 @@ public GetTableResponse doGetTable(final BlockAllocator blockAllocator, final Ge
TableName caseInsensitiveTableMatch = caseInsensitiveTableSearch(connection, getTableRequest.getTableName().getSchemaName(),
getTableRequest.getTableName().getTableName());
Schema caseInsensitiveSchemaMatch = getSchema(connection, caseInsensitiveTableMatch, partitionSchema);

return new GetTableResponse(getTableRequest.getCatalogName(), caseInsensitiveTableMatch, caseInsensitiveSchemaMatch,
partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()));
}
}

@Override
public GetTableResponse doGetQueryPassthroughSchema(final BlockAllocator blockAllocator, final GetTableRequest getTableRequest)
        throws Exception
{
    if (!getTableRequest.isQueryPassthrough()) {
        throw new IllegalArgumentException("No Query passed through [" + getTableRequest + "]");
    }

    // Validate that the request carries a well-formed set of QPT arguments before touching the database.
    jdbcQueryPassthrough.verify(getTableRequest.getQueryPassthroughArguments());
    String customerPassedQuery = getTableRequest.getQueryPassthroughArguments().get(JdbcQueryPassthrough.QUERY);

    // PreparedStatement joins the try-with-resources so it is closed even when schema resolution throws.
    try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider());
            PreparedStatement preparedStatement = connection.prepareStatement(customerPassedQuery)) {
        // Some drivers cannot describe a statement without executing it; in that case metadata is null.
        ResultSetMetaData metadata = preparedStatement.getMetaData();
        if (metadata == null) {
            throw new UnsupportedOperationException("Query not supported: ResultSetMetaData not available for query: " + customerPassedQuery);
        }
        SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder();

        for (int columnIndex = 1; columnIndex <= metadata.getColumnCount(); columnIndex++) {
            String columnName = metadata.getColumnName(columnIndex);
            String columnLabel = metadata.getColumnLabel(columnIndex);
            //todo; is there a mechanism to pass both back to the engine?
            // Prefer the label (alias) when it differs from the underlying column name.
            columnName = columnName.equals(columnLabel) ? columnName : columnLabel;

            int precision = metadata.getPrecision(columnIndex);
            int scale = metadata.getScale(columnIndex);

            ArrowType columnType = JdbcArrowTypeConverter.toArrowType(
                    metadata.getColumnType(columnIndex),
                    precision,
                    scale,
                    configOptions);

            if (columnType != null && SupportedTypes.isSupported(columnType)) {
                if (columnType instanceof ArrowType.List) {
                    schemaBuilder.addListField(columnName, getArrayArrowTypeFromTypeName(
                            metadata.getTableName(columnIndex),
                            metadata.getColumnDisplaySize(columnIndex),
                            precision));
                }
                else {
                    schemaBuilder.addField(FieldBuilder.newBuilder(columnName, columnType).build());
                }
            }
            else {
                // Default to VARCHAR ArrowType for anything we cannot map, so the query can still run.
                LOGGER.warn("getSchema: Unable to map type for column[" + columnName +
                        "] to a supported type, attempted " + columnType + " - defaulting type to VARCHAR.");
                schemaBuilder.addField(FieldBuilder.newBuilder(columnName, new ArrowType.Utf8()).build());
            }
        }

        Schema schema = schemaBuilder.build();
        // QPT has no partition columns, hence the empty partition-column set.
        return new GetTableResponse(getTableRequest.getCatalogName(), getTableRequest.getTableName(), schema, Collections.emptySet());
    }
}

/**
* While being a no-op by default, this function will be overriden by subclasses that support this search.
*
Expand Down Expand Up @@ -373,4 +440,21 @@ protected ArrowType getArrayArrowTypeFromTypeName(String typeName, int precision
// Default ARRAY type is VARCHAR.
return new ArrowType.Utf8();
}

/**
 * Builds the single, fixed split used to serve a Query Pass-Through request.
 *
 * @param request the split request being answered
 * @return a response containing exactly one split that carries the QPT arguments as split properties
 */
protected GetSplitsResponse setupQueryPassthroughSplit(GetSplitsRequest request)
{
    //Every split must have a unique location if we wish to spill to avoid failures
    SpillLocation spillLocation = makeSpillLocation(request);

    //Since this is QPT query we return a fixed split.
    Split qptSplit = Split.newBuilder(spillLocation, makeEncryptionKey())
            .applyProperties(request.getConstraints().getQueryPassthroughArguments())
            .build();
    return new GetSplitsResponse(request.getCatalogName(), qptSplit);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory;
import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider;
import com.amazonaws.athena.connectors.jdbc.connection.RdsSecretsCredentialProvider;
import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
Expand Down Expand Up @@ -97,6 +98,8 @@ public abstract class JdbcRecordHandler
private final JdbcConnectionFactory jdbcConnectionFactory;
private final DatabaseConnectionConfig databaseConnectionConfig;

protected final JdbcQueryPassthrough queryPassthrough = new JdbcQueryPassthrough();

/**
* Used only by Multiplexing handler. All invocations will be delegated to respective database handler.
*/
Expand Down Expand Up @@ -316,4 +319,13 @@ protected Extractor makeExtractor(Field field, ResultSet resultSet, Map<String,
*/
public abstract PreparedStatement buildSplitSql(Connection jdbcConnection, String catalogName, TableName tableName, Schema schema, Constraints constraints, Split split)
throws SQLException;

/**
 * Prepares the client-supplied Query Pass-Through SQL verbatim after verifying
 * the passthrough arguments.
 *
 * <p>The caller owns the returned statement and is responsible for closing it.</p>
 *
 * @param jdbcConnection the connection to prepare the statement on
 * @param constraints the constraints carrying the query passthrough arguments
 * @return a PreparedStatement wrapping the client's passthrough query
 * @throws SQLException if the statement cannot be prepared
 */
public PreparedStatement buildQueryPassthroughSql(Connection jdbcConnection, Constraints constraints) throws SQLException
{
    queryPassthrough.verify(constraints.getQueryPassthroughArguments());
    String clientPassQuery = constraints.getQueryPassthroughArguments().get(JdbcQueryPassthrough.QUERY);
    return jdbcConnection.prepareStatement(clientPassQuery);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*-
* #%L
* athena-jdbc
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.jdbc.qpt;

import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Implements the QPT signature interface to define the JDBC Query Passthrough
 * Function's signature, informing the engine how the QPT function is exposed
 * for a JDBC connector (i.e. {@code system.query(QUERY => '...')}).
 *
 * <p>NOTE(review): despite earlier docs calling this a Singleton, it is
 * instantiated directly by the metadata and record handlers; it is stateless,
 * so sharing one instance is an optimization, not a requirement.</p>
 */
public class JdbcQueryPassthrough implements QueryPassthroughSignature
{
    // Name of the passthrough function exposed to the engine.
    public static final String NAME = "query";

    // Schema (namespace) under which the passthrough function is registered.
    public static final String SCHEMA_NAME = "system";

    // Name of the function's single argument: the SQL text to pass through verbatim.
    public static final String QUERY = "QUERY";

    // Immutable argument list; the signature always takes exactly one argument.
    public static final List<String> ARGUMENTS = Collections.singletonList(QUERY);

    private static final Logger LOGGER = LoggerFactory.getLogger(JdbcQueryPassthrough.class);

    @Override
    public String getFunctionSchema()
    {
        return SCHEMA_NAME;
    }

    @Override
    public String getFunctionName()
    {
        return NAME;
    }

    @Override
    public List<String> getFunctionArguments()
    {
        return ARGUMENTS;
    }

    @Override
    public Logger getLogger()
    {
        return LOGGER;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
import java.util.Map;
import java.util.Set;

import static com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler.TABLES_AND_VIEWS;
import static com.amazonaws.athena.connectors.mysql.MySqlConstants.MYSQL_DEFAULT_PORT;
import static com.amazonaws.athena.connectors.mysql.MySqlConstants.MYSQL_DRIVER_CLASS;
import static com.amazonaws.athena.connectors.mysql.MySqlConstants.MYSQL_NAME;
Expand Down Expand Up @@ -137,6 +136,8 @@ public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAlloca
.toArray(String[]::new))
));

jdbcQueryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);

return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
}

Expand Down Expand Up @@ -193,6 +194,11 @@ public GetSplitsResponse doGetSplits(
final BlockAllocator blockAllocator, final GetSplitsRequest getSplitsRequest)
{
LOGGER.info("{}: Catalog {}, table {}", getSplitsRequest.getQueryId(), getSplitsRequest.getTableName().getSchemaName(), getSplitsRequest.getTableName().getTableName());
if (getSplitsRequest.getConstraints().isQueryPassThrough()) {
LOGGER.info("QPT Split Requested");
return setupQueryPassthroughSplit(getSplitsRequest);
}

int partitionContd = decodeContinuationToken(getSplitsRequest);
Set<Split> splits = new HashSet<>();
Block partitions = getSplitsRequest.getPartitions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,14 @@ public MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, Jdb
public PreparedStatement buildSplitSql(Connection jdbcConnection, String catalogName, TableName tableName, Schema schema, Constraints constraints, Split split)
throws SQLException
{
PreparedStatement preparedStatement = jdbcSplitQueryBuilder.buildSql(jdbcConnection, null, tableName.getSchemaName(), tableName.getTableName(), schema, constraints, split);
PreparedStatement preparedStatement;

if (constraints.isQueryPassThrough()) {
preparedStatement = buildQueryPassthroughSql(jdbcConnection, constraints);
}
else {
preparedStatement = jdbcSplitQueryBuilder.buildSql(jdbcConnection, null, tableName.getSchemaName(), tableName.getTableName(), schema, constraints, split);
}
// Disable fetching all rows.
preparedStatement.setFetchSize(Integer.MIN_VALUE);

Expand Down

0 comments on commit b45f1b7

Please sign in to comment.