diff --git a/internal/tests/integration/sql_test.go b/internal/tests/integration/sql_test.go
index 28c32d4..88d7182 100644
--- a/internal/tests/integration/sql_test.go
+++ b/internal/tests/integration/sql_test.go
@@ -69,6 +69,48 @@ func TestIntegration_Schema(t *testing.T) {
 	assert.Equal(t, types.LongType{}, schema.Fields[0].DataType)
 }
 
+func TestIntegration_StructConversion(t *testing.T) {
+	ctx := context.Background()
+	spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	query := `
+	select named_struct(
+		'a', 1,
+		'b', 2,
+		'c', cast(10.32 as double),
+		'd', array(1, 2, 3, 4)
+	) struct_col
+	`
+	df, err := spark.Sql(ctx, query)
+	assert.NoError(t, err)
+	res, err := df.Collect(ctx)
+	assert.NoError(t, err)
+	assert.Equal(t, 1, len(res))
+
+	columnData := res[0].Values()[0]
+	assert.NotNil(t, columnData)
+	structDataMap, ok := columnData.(map[string]any)
+	assert.True(t, ok)
+
+	assert.Contains(t, structDataMap, "a")
+	assert.Contains(t, structDataMap, "b")
+	assert.Contains(t, structDataMap, "c")
+	assert.Contains(t, structDataMap, "d")
+
+	assert.Equal(t, int32(1), structDataMap["a"])
+	assert.Equal(t, int32(2), structDataMap["b"])
+	assert.Equal(t, float64(10.32), structDataMap["c"])
+	arrayData := []any{int32(1), int32(2), int32(3), int32(4)}
+	assert.Equal(t, arrayData, structDataMap["d"])
+
+	schema, err := df.Schema(ctx)
+	assert.NoError(t, err)
+	assert.Equal(t, "struct_col", schema.Fields[0].Name)
+}
+
 func TestMain(m *testing.M) {
 	envShouldStartService := os.Getenv("START_SPARK_CONNECT_SERVICE")
 	shouldStartService := envShouldStartService == "" || envShouldStartService == "1"
diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go
index 2921881..47be9aa 100644
--- a/spark/sql/types/arrow.go
+++ b/spark/sql/types/arrow.go
@@ -260,6 +260,27 @@ func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) {
 			}
 			buf = append(buf, tmp)
 		}
+	case arrow.STRUCT:
+		data := array.NewStructData(data)
+		schema := data.DataType().(*arrow.StructType)
+
+		for i := 0; i < data.Len(); i++ {
+			if data.IsNull(i) {
+				buf = append(buf, nil)
+				continue
+			}
+			tmp := make(map[string]any)
+
+			for j := range data.NumField() {
+				field := data.Field(j)
+				fieldValues, err := readArrayData(field.DataType().ID(), field.Data())
+				if err != nil {
+					return nil, err
+				}
+				tmp[schema.Field(j).Name] = fieldValues[i]
+			}
+			buf = append(buf, tmp)
+		}
 	default:
 		return nil, fmt.Errorf("unsupported arrow data type %s", t.String())
 	}
diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go
index 7c2b925..8ec9fbb 100644
--- a/spark/sql/types/arrow_test.go
+++ b/spark/sql/types/arrow_test.go
@@ -137,6 +137,22 @@ func TestReadArrowRecord(t *testing.T) {
 			Name: "map_string_int32",
 			Type: arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32),
 		},
+		{
+			Name: "struct",
+			Type: arrow.StructOf(
+				arrow.Field{Name: "field1", Type: arrow.PrimitiveTypes.Int32},
+				arrow.Field{Name: "field2", Type: arrow.BinaryTypes.String},
+			),
+		},
+		{
+			Name: "nested_struct",
+			Type: arrow.StructOf(
+				arrow.Field{Name: "field1", Type: arrow.StructOf(
+					arrow.Field{Name: "nested_field1", Type: arrow.PrimitiveTypes.Int32},
+					arrow.Field{Name: "nested_field2", Type: arrow.BinaryTypes.String},
+				)},
+			),
+		},
 	}
 	arrowSchema := arrow.NewSchema(arrowFields, nil)
 	var buf bytes.Buffer
@@ -224,6 +240,30 @@ func TestReadArrowRecord(t *testing.T) {
 	mb.KeyBuilder().(*array.StringBuilder).Append("key2")
 	mb.ItemBuilder().(*array.Int32Builder).Append(2)
 
+	i++
+	sb := recordBuilder.Field(i).(*array.StructBuilder)
+	sb.Append(true)
+	sb.FieldBuilder(0).(*array.Int32Builder).Append(1)
+	sb.FieldBuilder(1).(*array.StringBuilder).Append("str1")
+
+	sb.Append(true)
+	sb.FieldBuilder(0).(*array.Int32Builder).Append(2)
+	sb.FieldBuilder(1).(*array.StringBuilder).Append("str2")
+
+	i++
+	sb = recordBuilder.Field(i).(*array.StructBuilder)
+	sb.Append(true)
+	nsb := sb.FieldBuilder(0).(*array.StructBuilder)
+	nsb.Append(true)
+	nsb.FieldBuilder(0).(*array.Int32Builder).Append(1)
+	nsb.FieldBuilder(1).(*array.StringBuilder).Append("str1_nested")
+
+	sb.Append(true)
+	nsb = sb.FieldBuilder(0).(*array.StructBuilder)
+	nsb.Append(true)
+	nsb.FieldBuilder(0).(*array.Int32Builder).Append(2)
+	nsb.FieldBuilder(1).(*array.StringBuilder).Append("str2_nested")
+
 	record := recordBuilder.NewRecord()
 	defer record.Release()
 
@@ -239,6 +279,13 @@ func TestReadArrowRecord(t *testing.T) {
 		arrow.Timestamp(1686981953115000), arrow.Date64(1686981953117000),
 		[]any{int64(1), int64(-999231)}, map[any]any{"key1": int32(1)},
+		map[string]any{"field1": int32(1), "field2": "str1"},
+		map[string]any{
+			"field1": map[string]any{
+				"nested_field1": int32(1),
+				"nested_field2": "str1_nested",
+			},
+		},
 	}, values[0].Values())
 
 	assert.Equal(t, []any{
@@ -249,6 +296,13 @@ func TestReadArrowRecord(t *testing.T) {
 		arrow.Timestamp(1686981953116000), arrow.Date64(1686981953118000),
 		[]any{int64(1), int64(2), int64(3)}, map[any]any{"key2": int32(2)},
+		map[string]any{"field1": int32(2), "field2": "str2"},
+		map[string]any{
+			"field1": map[string]any{
+				"nested_field1": int32(2),
+				"nested_field2": "str2_nested",
+			},
+		},
 	}, values[1].Values())
 }
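
Usage note: with the new arrow.STRUCT case in readArrayData, a struct column collected from a DataFrame is materialized as a map[string]any, and nested structs become nested maps. Below is a minimal sketch of how a caller might consume such a value; the session, query, and Row accessors mirror TestIntegration_StructConversion above, while the module import path and the use of panic for error handling are assumptions made to keep the example self-contained.

```go
package main

import (
	"context"
	"fmt"

	"github.com/apache/spark-connect-go/v35/spark/sql" // import path assumed
)

func main() {
	ctx := context.Background()

	// Connect to a local Spark Connect endpoint, as in the integration test.
	spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)
	if err != nil {
		panic(err)
	}

	df, err := spark.Sql(ctx, "select named_struct('a', 1, 'c', cast(10.32 as double)) struct_col")
	if err != nil {
		panic(err)
	}

	rows, err := df.Collect(ctx)
	if err != nil {
		panic(err)
	}

	// With this change, the struct column arrives as a map[string]any keyed
	// by the arrow.StructType field names.
	structVal, ok := rows[0].Values()[0].(map[string]any)
	if !ok {
		panic("expected map[string]any for a struct column")
	}
	fmt.Println(structVal["a"], structVal["c"]) // 1 10.32
}
```

Returning map[string]any rather than a generated Go struct keeps the conversion schema-driven at runtime: field names are read from the arrow.StructType, and each child array is decoded by the same readArrayData recursion, which is why arbitrarily nested structs (as exercised by the nested_struct test case) work without extra code.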