Skip to content

Commit

Permalink
[Bug][Translation][Spark] Fix SeaTunnelRowConvertor fail to convert when schema contains row type. (#5170)
Browse files Browse the repository at this point in the history
  • Loading branch information
CheneyYin authored Jul 31, 2023
1 parent e4f666f commit 1db9f45
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 7 deletions.
2 changes: 2 additions & 0 deletions release-note.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
- [Core] [API] Fix parse nested row data type key changed upper (#4459)
- [Starter][Flink]Support transform-v2 for flink #3396
- [Flink] Support flink 1.14.x #3963
- [Core][Translation][Spark] Fix SeaTunnelRowConvertor fail to convert when schema contains row type (#5170)

### Transformer
- [Spark] Support transform-v2 for spark (#3409)
- [ALL]Add FieldMapper Transform #3781
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ public Row next() {
return null;
}
seaTunnelRow = outputRowConverter.convert(seaTunnelRow);

return new GenericRowWithSchema(seaTunnelRow.getFields(), structType);
} catch (Exception e) {
throw new TaskExecuteException("Row convert failed, caused: " + e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ source {
fields {
id = "int"
name = "string"
c_row = {
c_row = {
c_int = int
}
}
}
}
}
Expand All @@ -49,6 +54,7 @@ transform {
id_1 = "id"
name2 = "name"
name3 = "name"
c_row_1 = "c_row"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ source {
id = "int"
name = "string"
age = "int"
c_row = {
c_row = {
c_int = int
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ source {
id = "int"
name = "string"
age = "int"
c_row = {
c_row = {
c_int = int
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ source {
id = "int"
name = "string"
age = "int"
c_row = {
c_row = {
c_int = int
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ source {
id = "int"
name = "string"
age = "int"
c_row = {
c_row = {
c_int = int
}
}
}
}
}
Expand All @@ -40,7 +45,7 @@ transform {
Filter {
source_table_name = "fake"
result_table_name = "fake1"
fields = ["age", "name"]
fields = ["age", "name", "c_row"]
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ source {
id = "int"
name = "string"
age = "int"
c_row = {
c_row = {
c_int = int
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ source {
string1 = "string"
int1 = "int"
c_bigint = "bigint"
c_row = {
c_row = {
c_int = int
}
}
}
}
}
Expand All @@ -48,6 +53,7 @@ transform {
age = age_as
int1 = int1_as
name = name
c_row = c_row
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ source {
c_map = "map<string, string>"
c_array = "array<int>"
c_decimal = "decimal(30, 8)"
c_row = {
c_row = {
c_int = int
}
}
}
}
}
Expand All @@ -46,7 +51,7 @@ transform {
source_table_name = "fake"
result_table_name = "fake1"
# the query table name must same as field 'source_table_name'
query = "select id, regexp_replace(name, '.+', 'b') as name, age+1 as age, pi() as pi, c_timestamp, c_date, c_map, c_array, c_decimal from fake"
query = "select id, regexp_replace(name, '.+', 'b') as name, age+1 as age, pi() as pi, c_timestamp, c_date, c_map, c_array, c_decimal, c_row from fake"
}
# The SQL transform support base function and criteria operation
# But the complex SQL unsupported yet, include: multi source table/rows JOIN and AGGREGATE operation and the like
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.translation.serialization.RowConverter;
import org.apache.seatunnel.translation.spark.utils.TypeConverterUtils;

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

import scala.Tuple2;
Expand All @@ -51,7 +54,11 @@ public SeaTunnelRowConverter(SeaTunnelDataType<?> dataType) {
@Override
public SeaTunnelRow convert(SeaTunnelRow seaTunnelRow) throws IOException {
    validate(seaTunnelRow);
    // Convert through Spark's row representation so that nested ROW fields are
    // materialized as GenericRowWithSchema values, then unwrap the top level
    // back into a SeaTunnelRow for the engine.
    GenericRowWithSchema rowWithSchema = (GenericRowWithSchema) convert(seaTunnelRow, dataType);
    SeaTunnelRow newRow = new SeaTunnelRow(rowWithSchema.values());
    // Carry over row metadata that the field conversion does not touch.
    newRow.setRowKind(seaTunnelRow.getRowKind());
    newRow.setTableId(seaTunnelRow.getTableId());
    return newRow;
}

private Object convert(Object field, SeaTunnelDataType<?> dataType) {
Expand All @@ -62,7 +69,7 @@ private Object convert(Object field, SeaTunnelDataType<?> dataType) {
case ROW:
SeaTunnelRow seaTunnelRow = (SeaTunnelRow) field;
SeaTunnelRowType rowType = (SeaTunnelRowType) dataType;
return convert(seaTunnelRow, rowType);
return convertRow(seaTunnelRow, rowType);
case DATE:
return Date.valueOf((LocalDate) field);
case TIMESTAMP:
Expand Down Expand Up @@ -94,16 +101,17 @@ private Object convert(Object field, SeaTunnelDataType<?> dataType) {
}
}

private SeaTunnelRow convert(SeaTunnelRow seaTunnelRow, SeaTunnelRowType rowType) {
private GenericRowWithSchema convertRow(SeaTunnelRow seaTunnelRow, SeaTunnelRowType rowType) {
int arity = rowType.getTotalFields();
Object[] values = new Object[arity];
StructType schema = (StructType) TypeConverterUtils.convert(rowType);
for (int i = 0; i < arity; i++) {
Object fieldValue = convert(seaTunnelRow.getField(i), rowType.getFieldType(i));
if (fieldValue != null) {
values[i] = fieldValue;
}
}
return new SeaTunnelRow(values);
return new GenericRowWithSchema(values, schema);
}

private scala.collection.immutable.HashMap<Object, Object> convertMap(
Expand Down Expand Up @@ -148,6 +156,10 @@ private Object reconvert(Object field, SeaTunnelDataType<?> dataType) {
}
switch (dataType.getSqlType()) {
case ROW:
if (field instanceof GenericRowWithSchema) {
return createFromGenericRow(
(GenericRowWithSchema) field, (SeaTunnelRowType) dataType);
}
return reconvert((SeaTunnelRow) field, (SeaTunnelRowType) dataType);
case DATE:
return ((Date) field).toLocalDate();
Expand All @@ -166,6 +178,15 @@ private Object reconvert(Object field, SeaTunnelDataType<?> dataType) {
}
}

/**
 * Rebuilds a {@link SeaTunnelRow} from a Spark {@link GenericRowWithSchema},
 * reconverting each field value back to its SeaTunnel representation.
 */
private SeaTunnelRow createFromGenericRow(GenericRowWithSchema row, SeaTunnelRowType type) {
    Object[] sparkValues = row.values();
    int arity = sparkValues.length;
    Object[] converted = new Object[arity];
    for (int i = 0; i < arity; i++) {
        converted[i] = reconvert(sparkValues[i], type.getFieldType(i));
    }
    return new SeaTunnelRow(converted);
}

private SeaTunnelRow reconvert(SeaTunnelRow engineRow, SeaTunnelRowType rowType) {
int num = engineRow.getFields().length;
Object[] fields = new Object[num];
Expand Down

0 comments on commit 1db9f45

Please sign in to comment.