Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRILL-8465: Check Input Data for Iceberg Plugin #2853

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.util.Base64;
import java.util.Objects;
import java.util.StringJoiner;
Expand Down Expand Up @@ -92,7 +94,8 @@ public IcebergWorkDeserializer() {
public IcebergWork deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
JsonNode node = p.getCodec().readTree(p);
String scanTaskString = node.get(IcebergWorkSerializer.SCAN_TASK_FIELD).asText();
try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(Base64.getDecoder().decode(scanTaskString)))) {
try (ObjectInputStream ois = new ScanTaskObjectInputStream(
new ByteArrayInputStream(Base64.getDecoder().decode(scanTaskString)))) {
Object scanTask = ois.readObject();
return new IcebergWork((CombinedScanTask) scanTask);
} catch (ClassNotFoundException e) {
Expand All @@ -103,6 +106,35 @@ public IcebergWork deserialize(JsonParser p, DeserializationContext ctxt) throws
}
}

private static class ScanTaskObjectInputStream extends ObjectInputStream {

ScanTaskObjectInputStream(InputStream inputStream) throws IOException {
super(inputStream);
}

@Override
protected Class<?> resolveClass(ObjectStreamClass cls) throws IOException, ClassNotFoundException {
final String className = cls.getName();
if (isValidPackage(className)) {
return super.resolveClass(cls);
}
final Class<?> resolvedClass = super.resolveClass(cls);
if ((resolvedClass.isArray() &&
(resolvedClass.getComponentType().isPrimitive() ||
isValidPackage(resolvedClass.getComponentType().getName())))
|| resolvedClass.isPrimitive()) {
return resolvedClass;
}
throw new IOException("Rejected deserialization of unexpected class: " + className);
}

private boolean isValidPackage(final String className) {
return className.startsWith("org.apache.iceberg.") ||
className.startsWith("org.apache.drill.") ||
className.startsWith("java.");
}
}

/**
* Special serializer for {@link IcebergWork} class that serializes
* {@code scanTask} field to byte array string created using {@link java.io.Serializable}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class IcebergFormatPluginConfig implements FormatPluginConfig {

private final Boolean ignoreResiduals;

private final Boolean allowAnyClassToBeLoaded;

private final Long snapshotId;

private final Long snapshotAsOfTime;
Expand All @@ -60,6 +62,7 @@ public IcebergFormatPluginConfig(
this.caseSensitive = builder.caseSensitive;
this.includeColumnStats = builder.includeColumnStats;
this.ignoreResiduals = builder.ignoreResiduals;
this.allowAnyClassToBeLoaded = builder.allowAnyClassToBeLoaded;
this.snapshotId = builder.snapshotId;
this.snapshotAsOfTime = builder.snapshotAsOfTime;
this.fromSnapshotId = builder.fromSnapshotId;
Expand Down Expand Up @@ -100,6 +103,10 @@ public Boolean getIgnoreResiduals() {
return this.ignoreResiduals;
}

public Boolean getAllowAnyClassToBeLoaded() {
return this.allowAnyClassToBeLoaded;
}

public Long getSnapshotId() {
return this.snapshotId;
}
Expand Down Expand Up @@ -130,6 +137,7 @@ public boolean equals(Object o) {
&& Objects.equals(caseSensitive, that.caseSensitive)
&& Objects.equals(includeColumnStats, that.includeColumnStats)
&& Objects.equals(ignoreResiduals, that.ignoreResiduals)
&& Objects.equals(allowAnyClassToBeLoaded, that.allowAnyClassToBeLoaded)
&& Objects.equals(snapshotId, that.snapshotId)
&& Objects.equals(snapshotAsOfTime, that.snapshotAsOfTime)
&& Objects.equals(fromSnapshotId, that.fromSnapshotId)
Expand All @@ -138,8 +146,8 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(properties, snapshot, caseSensitive, includeColumnStats,
ignoreResiduals, snapshotId, snapshotAsOfTime, fromSnapshotId, toSnapshotId);
return Objects.hash(properties, snapshot, caseSensitive, includeColumnStats, ignoreResiduals,
allowAnyClassToBeLoaded, snapshotId, snapshotAsOfTime, fromSnapshotId, toSnapshotId);
}

@JsonPOJOBuilder(withPrefix = "")
Expand All @@ -152,6 +160,8 @@ public static class IcebergFormatPluginConfigBuilder {

private Boolean ignoreResiduals;

private Boolean allowAnyClassToBeLoaded;

private Long snapshotId;

private Long snapshotAsOfTime;
Expand Down Expand Up @@ -180,6 +190,11 @@ public IcebergFormatPluginConfigBuilder ignoreResiduals(Boolean ignoreResiduals)
return this;
}

public IcebergFormatPluginConfigBuilder allowAnyClassToBeLoaded(Boolean allowAnyClassToBeLoaded) {
this.allowAnyClassToBeLoaded = allowAnyClassToBeLoaded;
return this;
}

public IcebergFormatPluginConfigBuilder snapshotId(Long snapshotId) {
this.snapshotId = snapshotId;
return this;
Expand Down