diff --git a/cmd/struct.go b/cmd/struct.go index 5825120..f438f9a 100644 --- a/cmd/struct.go +++ b/cmd/struct.go @@ -16,6 +16,10 @@ var structCmd = &cobra.Command{ Run: structRun, } +const ( + repeat = " " +) + func init() { rootCmd.AddCommand(structCmd) } @@ -32,39 +36,64 @@ func structRun(cmd *cobra.Command, args []string) { } func printGoStruct(node *schema.GroupNode, w *os.File, depth int) { - indent := strings.Repeat("\t", depth) if depth == 0 { _, _ = w.WriteString(fmt.Sprintf("type %s struct {\n", node.Name())) } - + depth++ + indent := strings.Repeat(repeat, depth) for i := 0; i < node.NumFields(); i++ { field := node.Field(i) fieldName := field.Name() if group, ok := field.(*schema.GroupNode); ok { - // 对于嵌套结构,直接在字段定义中展开结构体 - _, _ = w.WriteString(fmt.Sprintf("%s%s struct {\n", - indent, toCamelCase(fieldName))) - // 递归处理嵌套字段 - for j := 0; j < group.NumFields(); j++ { - nestedField := group.Field(j) - nestedName := nestedField.Name() - nestedType := parquetTypeToGoType(nestedField) - _, _ = w.WriteString(fmt.Sprintf("%s\t%s %s `parquet:\"%s\"`\n", - indent, toCamelCase(nestedName), nestedType, nestedName)) + // 检查是否是简单的 List 结构 + if isSimpleList(group) { + elementType := getListElementType(group) + _, _ = fmt.Fprintf(w, "%s%s []%s `parquet:\"%s\"`\n", + indent, toCamelCase(fieldName), elementType, fieldName) + } else { + // 嵌套结构处理 + _, _ = fmt.Fprintf(w, "%s%s struct {\n", + indent, toCamelCase(fieldName)) + // 递归处理嵌套字段,增加缩进 + for j := 0; j < group.NumFields(); j++ { + nestedField := group.Field(j) + nestedName := nestedField.Name() + indent = strings.Repeat(repeat, depth+1) + nestedType := parquetTypeToGoType(nestedField) + _, _ = fmt.Fprintf(w, "%s%s %s `parquet:\"%s\"`\n", + indent, toCamelCase(nestedName), nestedType, nestedName) + indent = strings.Repeat(repeat, depth) + } + // 结束嵌套结构体定义 + _, _ = fmt.Fprintf(w, "%s} `parquet:\"%s\"`\n", + indent, fieldName) } - _, _ = w.WriteString(fmt.Sprintf("%s} `parquet:\"%s\"`\n", - indent, fieldName)) } else { goType := parquetTypeToGoType(field) - _, _ = w.WriteString(fmt.Sprintf("%s%s %s `parquet:\"%s\"`\n", - indent, toCamelCase(fieldName), goType, fieldName)) + _, _ = fmt.Fprintf(w, "%s%s %s `parquet:\"%s\"`\n", + indent, toCamelCase(fieldName), goType, fieldName) } } - + depth-- + indent = strings.Repeat(repeat, depth) if depth == 0 { - _, _ = w.WriteString("}\n") + _, _ = fmt.Fprintf(w, "%s}\n", indent) + } +} + +// 判断是否是简单的 List 结构(只包含一个 list 字段的结构) +func isSimpleList(group *schema.GroupNode) bool { + return group.NumFields() == 1 && + group.Field(0).Name() == "list" +} + +// 获取 List 的元素类型 +func getListElementType(group *schema.GroupNode) string { + if group.NumFields() == 1 { + return parquetTypeToGoType(group.Field(0)) } + return "any" } // 辅助函数:将 parquet 数据类型转换为 Go 类型 @@ -72,7 +101,7 @@ func parquetTypeToGoType(field schema.Node) string { logicalType := field.LogicalType().String() switch { // String - case strings.HasPrefix(logicalType, "String"): + case strings.HasPrefix(logicalType, "String") || strings.HasPrefix(logicalType, "string"): return "string" // Int fmt.Sprintf("Int(bitWidth=%d, isSigned=%t)", t.typ.GetBitWidth(), t.typ.GetIsSigned()) case strings.HasPrefix(logicalType, "Int"): @@ -88,32 +117,32 @@ func parquetTypeToGoType(field schema.Node) string { } return "uint32" // Decimal 格式为:fmt.Sprintf("Decimal(precision=%d, scale=%d)", t.typ.Precision, t.typ.Scale) - case strings.HasPrefix(logicalType, "Decimal"): + case strings.HasPrefix(logicalType, "Decimal") || strings.HasPrefix(logicalType, "decimal"): return "float64" // Date - case strings.HasPrefix(logicalType, "Date"): + case strings.HasPrefix(logicalType, "Date") || strings.HasPrefix(logicalType, "date"): return "time.Time" // Time fmt.Sprintf("Time(isAdjustedToUTC=%t, timeUnit=%s)", t.typ.GetIsAdjustedToUTC(), timeUnitToString(t.typ.GetUnit())) - case strings.HasPrefix(logicalType, "Time"): + case strings.HasPrefix(logicalType, "Time") || strings.HasPrefix(logicalType, "time"): return "time.Time" // Timestamp fmt.Sprintf("Timestamp(isAdjustedToUTC=%t, timeUnit=%s, is_from_converted_type=%t, force_set_converted_type=%t)",t.typ.GetIsAdjustedToUTC(), timeUnitToString(t.typ.GetUnit()), t.fromConverted, t.forceConverted) - case strings.HasPrefix(logicalType, "Timestamp"): + case strings.HasPrefix(logicalType, "Timestamp") || strings.HasPrefix(logicalType, "timestamp"): return "time.Time" - // Float16 - case strings.HasPrefix(logicalType, "Float"): + // Float16 + case strings.HasPrefix(logicalType, "Float") || strings.HasPrefix(logicalType, "float"): return "float32" - case strings.HasPrefix(logicalType, "Double"): + case strings.HasPrefix(logicalType, "Double") || strings.HasPrefix(logicalType, "double"): return "float64" - case strings.HasPrefix(logicalType, "Boolean"): + case strings.HasPrefix(logicalType, "Boolean") || strings.HasPrefix(logicalType, "boolean"): return "bool" - case strings.HasPrefix(logicalType, "Binary"): + case strings.HasPrefix(logicalType, "Binary") || strings.HasPrefix(logicalType, "binary"): return "[]byte" - case strings.HasPrefix(logicalType, "JSON"): + case strings.HasPrefix(logicalType, "JSON") || strings.HasPrefix(logicalType, "json"): return "json.RawMessage" - case strings.HasPrefix(logicalType, "UUID"): + case strings.HasPrefix(logicalType, "UUID") || strings.HasPrefix(logicalType, "uuid"): return "uuid.UUID" // github.com/google/uuid // List - case strings.HasPrefix(logicalType, "List"): + case strings.HasPrefix(logicalType, "List") || strings.HasPrefix(logicalType, "list"): // 如果是 List 类型,尝试获取元素类型 if listNode, ok := field.(*schema.GroupNode); ok && listNode.NumFields() > 0 { elementField := listNode.Field(0) @@ -122,21 +151,21 @@ func parquetTypeToGoType(field schema.Node) string { } return "[]any" // Map - case strings.HasPrefix(logicalType, "Map"): + case strings.HasPrefix(logicalType, "Map") || strings.HasPrefix(logicalType, "map"): return "map[string]any" - case strings.HasPrefix(logicalType, "Array"): + case strings.HasPrefix(logicalType, "Array") || strings.HasPrefix(logicalType, "array"): return "[]" - case strings.HasPrefix(logicalType, "Struct"): + case strings.HasPrefix(logicalType, "Struct") || strings.HasPrefix(logicalType, "struct"): return "struct" - case strings.HasPrefix(logicalType, "Enum"): + case strings.HasPrefix(logicalType, "Enum") || strings.HasPrefix(logicalType, "enum"): return "string" - case strings.Contains(logicalType, "Interval"): + case strings.Contains(logicalType, "Interval") || strings.Contains(logicalType, "interval"): return "time.Duration" - case strings.HasPrefix(logicalType, "Unknown"): + case strings.HasPrefix(logicalType, "Unknown") || strings.HasPrefix(logicalType, "unknown"): return "any" - case strings.HasPrefix(logicalType, "Null"): + case strings.HasPrefix(logicalType, "Null") || strings.HasPrefix(logicalType, "null"): return "any" - case strings.HasPrefix(logicalType, "None"): + case strings.HasPrefix(logicalType, "None") || strings.HasPrefix(logicalType, "none"): return "any" default: return "any" diff --git a/testdata/all_type.parquet b/testdata/all_type.parquet new file mode 100644 index 0000000..d8d6740 Binary files /dev/null and b/testdata/all_type.parquet differ