Skip to content

Commit

Permalink
Custom schema write updates with logging
Browse files Browse the repository at this point in the history
  • Loading branch information
xiazcy committed Apr 22, 2024
1 parent 055280b commit ba27ee8
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.amazonaws.athena.connectors.neptune.propertygraph.NeptuneGremlinConnection;
import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.neptune.auth.NeptuneNettyHttpSigV4Signer;
import com.amazonaws.neptune.auth.NeptuneSigV4SignerException;
Expand All @@ -29,11 +30,14 @@
import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection;
import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NeptuneConnection
{
private static Cluster cluster = null;

private static final Logger logger = LoggerFactory.getLogger(NeptuneConnection.class);

private String neptuneEndpoint;
private String neptunePort;
private boolean enabledIAM;
Expand All @@ -47,14 +51,17 @@ protected NeptuneConnection(String neptuneEndpoint, String neptunePort, boolean
.enableSsl(true);

if (enabledIAM) {
logger.info("Connecting with IAM auth to https://" + neptuneEndpoint + ":" + neptunePort + " in " + region);
final AWSCredentialsProvider awsCredentialsProvider = new DefaultAWSCredentialsProviderChain();
builder.handshakeInterceptor(r ->
{
try {
NeptuneNettyHttpSigV4Signer sigV4Signer =
new NeptuneNettyHttpSigV4Signer(region, new DefaultAWSCredentialsProviderChain());
new NeptuneNettyHttpSigV4Signer(region, awsCredentialsProvider);
sigV4Signer.signRequest(r);
}
catch (NeptuneSigV4SignerException e) {
logger.error("SIGV4 exception", e);
throw new RuntimeException("Exception occurred while signing the request", e);
}
return r;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package com.amazonaws.athena.connectors.neptune.propertygraph;

import com.amazonaws.athena.connectors.neptune.NeptuneConnection;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.neptune.auth.NeptuneNettyHttpSigV4Signer;
import com.amazonaws.neptune.auth.NeptuneSigV4SignerException;
Expand All @@ -28,9 +29,12 @@
import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection;
import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NeptuneGremlinConnection extends NeptuneConnection
{
private static final Logger logger = LoggerFactory.getLogger(NeptuneGremlinConnection.class);
private static Cluster cluster = null;

public NeptuneGremlinConnection(String neptuneEndpoint, String neptunePort, boolean enabledIAM, String region)
Expand All @@ -42,14 +46,17 @@ public NeptuneGremlinConnection(String neptuneEndpoint, String neptunePort, bool
.enableSsl(true);

if (enabledIAM) {
logger.info("Connecting with IAM auth to https://" + neptuneEndpoint + ":" + neptunePort + " in " + region);
final AWSCredentialsProvider awsCredentialsProvider = new DefaultAWSCredentialsProviderChain();
builder.handshakeInterceptor(r ->
{
try {
NeptuneNettyHttpSigV4Signer sigV4Signer =
new NeptuneNettyHttpSigV4Signer(region, new DefaultAWSCredentialsProviderChain());
new NeptuneNettyHttpSigV4Signer(region, awsCredentialsProvider);
sigV4Signer.signRequest(r);
}
catch (NeptuneSigV4SignerException e) {
logger.error("SIGV4 exception", e);
throw new RuntimeException("Exception occurred while signing the request", e);
}
return r;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.tinkerpop.gremlin.structure.T;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Date;
Expand All @@ -52,6 +54,7 @@
*/
public final class CustomSchemaRowWriter
{
private static final Logger logger = LoggerFactory.getLogger(CustomSchemaRowWriter.class);
private CustomSchemaRowWriter()
{
// Empty private constructor
Expand All @@ -61,8 +64,10 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
{
ArrowType arrowType = field.getType();
Types.MinorType minorType = Types.getMinorTypeForArrowType(arrowType);
logger.debug("writeRowTemplate*" + field.getName() + "*" + minorType + "*");
Boolean enableCaseinsensitivematch = (configOptions.get(Constants.SCHEMA_CASE_INSEN) == null) ? true : Boolean.parseBoolean(configOptions.get(Constants.SCHEMA_CASE_INSEN));

try {
switch (minorType) {
case BIT:
rowWriterBuilder.withExtractor(field.getName(),
Expand All @@ -72,6 +77,9 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
logger.debug("writeRowTemplate BIT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");

if (fieldValue.getClass().equals(Boolean.class)) {
Boolean booleanValue = Boolean.parseBoolean(fieldValue.toString());
value.value = booleanValue ? 1 : 0;
Expand Down Expand Up @@ -105,20 +113,26 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
}
else {
Object fieldValue = obj.get(fieldName);
logger.debug("writeRowTemplate VARCHAR*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");

if (fieldValue != null) {
if (fieldValue.getClass().equals(String.class)) {
value.value = fieldValue.toString();
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null) {
value.value = objValues.get(0).toString();
value.isSet = 1;
}
}
}
else {
value.value = "" + fieldValue;
value.isSet = 1;
}
}
}
});
break;
Expand All @@ -131,6 +145,8 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
logger.debug("writeRowTemplate DATEMILLI*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Date.class)) {
value.value = ((Date) fieldValue).getTime();
value.isSet = 1;
Expand All @@ -153,11 +169,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Integer.class)) {
logger.debug("writeRowTemplate INT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Integer.class)) {
value.value = Integer.parseInt(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Integer.parseInt(objValues.get(0).toString());
Expand All @@ -175,11 +193,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Long.class)) {
logger.debug("writeRowTemplate BIGINT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Long.class)) {
value.value = Long.parseLong(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Long.parseLong(objValues.get(0).toString());
Expand All @@ -197,11 +217,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Float.class)) {
logger.debug("writeRowTemplate FLOAT4*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Float.class)) {
value.value = Float.parseFloat(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Float.parseFloat(objValues.get(0).toString());
Expand All @@ -218,12 +240,14 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
Map<String, Object> obj = (Map<String, Object>) contextAsMap(context, enableCaseinsensitivematch);
value.isSet = 0;

Object fieldValue = obj.get(field.getName());
if (fieldValue.getClass().equals(Double.class)) {
Object fieldValue = obj.get(fieldName);
logger.debug("writeRowTemplate FLOAT8*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Double.class)) {
value.value = Double.parseDouble(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Double.parseDouble(objValues.get(0).toString());
Expand All @@ -234,6 +258,11 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie

break;
}
}
catch (Throwable e) {
logger.error("writeRowTemplate exception for *" + field.getName() + "*" + minorType + "*", e);
throw new RuntimeException(e);
}
}

private static Map<String, Object> contextAsMap(Object context, boolean caseInsensitive)
Expand Down

0 comments on commit ba27ee8

Please sign in to comment.