Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fix ArrayIndexOutOfBoundsException when updating nested fields with null values. #4156

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,19 @@ String name() {

@Override
public Object agg(Object accumulator, Object inputField) {
if (accumulator == null || inputField == null) {
return accumulator == null ? inputField : accumulator;
if (accumulator == null) {
return inputField;
}
if (inputField == null) {
return accumulator;
}

InternalArray acc = (InternalArray) accumulator;
InternalArray input = (InternalArray) inputField;

List<InternalRow> rows = new ArrayList<>();
for (int i = 0; i < acc.size(); i++) {
rows.add(acc.getRow(i, nestedFields));
}
for (int i = 0; i < input.size(); i++) {
rows.add(input.getRow(i, nestedFields));
}
List<InternalRow> rows = new ArrayList<>(acc.size() + input.size());
addNonNullRows(acc, rows);
addNonNullRows(input, rows);

if (keyProjection != null) {
Map<BinaryRow, InternalRow> map = new HashMap<>();
Expand All @@ -111,10 +110,11 @@ public Object retract(Object accumulator, Object retractField) {
if (keyProjection == null) {
checkNotNull(elementEqualiser);
List<InternalRow> rows = new ArrayList<>();
for (int i = 0; i < acc.size(); i++) {
rows.add(acc.getRow(i, nestedFields));
}
addNonNullRows(acc, rows);
for (int i = 0; i < retract.size(); i++) {
if (retract.isNullAt(i)) {
continue;
}
InternalRow retractRow = retract.getRow(i, nestedFields);
rows.removeIf(next -> elementEqualiser.equals(next, retractRow));
}
Expand All @@ -123,15 +123,30 @@ public Object retract(Object accumulator, Object retractField) {
Map<BinaryRow, InternalRow> map = new HashMap<>();

for (int i = 0; i < acc.size(); i++) {
if (acc.isNullAt(i)) {
continue;
}
InternalRow row = acc.getRow(i, nestedFields);
map.put(keyProjection.apply(row).copy(), row);
}

for (int i = 0; i < retract.size(); i++) {
if (retract.isNullAt(i)) {
continue;
}
map.remove(keyProjection.apply(retract.getRow(i, nestedFields)));
}

return new GenericArray(new ArrayList<>(map.values()).toArray());
}
}

private void addNonNullRows(InternalArray array, List<InternalRow> rows) {
for (int i = 0; i < array.size(); i++) {
if (array.isNullAt(i)) {
continue;
}
rows.add(array.getRow(i, nestedFields));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -1321,6 +1322,39 @@ public void testUseCase() {
Row.of(3, "Liu", "NanJing", 1, "12-26", "Cup", 30L));
}

@Test
public void testUseCaseWithNullValue() {
sql(
"INSERT INTO order_wide\n"
+ "SELECT 6, CAST (NULL AS STRING), CAST (NULL AS STRING), "
+ "ARRAY[cast(null as ROW<daily_id INT, today STRING, product_name STRING, price BIGINT>)]");

List<Row> result =
sql("SELECT * FROM order_wide").stream()
.sorted(Comparator.comparingInt(r -> r.getFieldAs(0)))
.collect(Collectors.toList());

assertThat(checkOneRecord(result.get(0), 6, null, null, (Row) null)).isTrue();

sql(
"INSERT INTO order_wide\n"
+ "SELECT 6, 'Sun', CAST (NULL AS STRING), "
+ "ARRAY[ROW(1, '01-01','Apple', 6999)]");

result =
sql("SELECT * FROM order_wide").stream()
.sorted(Comparator.comparingInt(r -> r.getFieldAs(0)))
.collect(Collectors.toList());
assertThat(
checkOneRecord(
result.get(0),
6,
"Sun",
null,
Row.of(1, "01-01", "Apple", 6999L)))
.isTrue();
}

@Test
public void testUseCaseAppend() {
sql(
Expand Down Expand Up @@ -1429,10 +1463,10 @@ private boolean checkOneRecord(
if ((int) record.getField(0) != orderId) {
return false;
}
if (!record.getFieldAs(1).equals(userName)) {
if (!Objects.equals(record.getFieldAs(1), userName)) {
return false;
}
if (!record.getFieldAs(2).equals(address)) {
if (!Objects.equals(record.getFieldAs(2), address)) {
return false;
}

Expand All @@ -1455,7 +1489,7 @@ private boolean checkNestedTable(Row[] nestedTable, Row... subOrders) {
Arrays.stream(subOrders).sorted(comparator).collect(Collectors.toList());

for (int i = 0; i < sortedActual.size(); i++) {
if (!sortedActual.get(i).equals(sortedExpected.get(i))) {
if (!Objects.equals(sortedActual.get(i), sortedExpected.get(i))) {
return false;
}
}
Expand Down
Loading