diff --git a/internal/protoschema/bigquery/bigquery.go b/internal/protoschema/bigquery/bigquery.go index 9731744..08c3a01 100644 --- a/internal/protoschema/bigquery/bigquery.go +++ b/internal/protoschema/bigquery/bigquery.go @@ -98,10 +98,11 @@ func Generate(input protoreflect.MessageDescriptor, opts ...GenerateOptions) (bi "google.protobuf.ListValue", )) generator := &bigQuerySchemaGenerator{ - maxDepth: options.maxDepth, - maxRecursionDepth: options.maxRecursionDepth, - seen: make(map[protoreflect.FullName]int), - normalizer: normalizer, + maxDepth: options.maxDepth, + maxRecursionDepth: options.maxRecursionDepth, + generateAllMessages: options.generateAllMessages, + seen: make(map[protoreflect.FullName]int), + normalizer: normalizer, } if generator.maxDepth == 0 || generator.maxDepth > 15 { generator.maxDepth = 15 @@ -121,10 +122,11 @@ func Generate(input protoreflect.MessageDescriptor, opts ...GenerateOptions) (bi } type bigQuerySchemaGenerator struct { - maxDepth int - maxRecursionDepth int - seen map[protoreflect.FullName]int - normalizer *normalize.Normalizer + maxDepth int + maxRecursionDepth int + generateAllMessages bool + seen map[protoreflect.FullName]int + normalizer *normalize.Normalizer } func (p *bigQuerySchemaGenerator) generate(msgDesc protoreflect.MessageDescriptor, depth int) (bigquery.Schema, error) { @@ -140,7 +142,9 @@ func (p *bigQuerySchemaGenerator) generateFields(msgDesc protoreflect.MessageDes if err != nil { return nil, err } - + if msgOptions == nil && !p.generateAllMessages { + return nil, nil + } msgPb, err := p.normalizer.FindDescriptorProto(msgDesc) if err != nil { return nil, err @@ -174,9 +178,12 @@ func (p *bigQuerySchemaGenerator) generateFields(msgDesc protoreflect.MessageDes return result, nil } -func (p *bigQuerySchemaGenerator) generateField(msgOptions *bqproto.BigQueryMessageOptions, - fieldDesc protoreflect.FieldDescriptor, fieldPb *descriptorpb.FieldDescriptorProto, - depth int) (*bigquery.FieldSchema, error) { +func (p *bigQuerySchemaGenerator) generateField( + msgOptions *bqproto.BigQueryMessageOptions, + fieldDesc protoreflect.FieldDescriptor, + fieldPb *descriptorpb.FieldDescriptorProto, + depth int, +) (*bigquery.FieldSchema, error) { if depth >= p.maxDepth { return nil, nil } diff --git a/internal/protoschema/bigquery/options.go b/internal/protoschema/bigquery/options.go index 3b7c96b..1758018 100644 --- a/internal/protoschema/bigquery/options.go +++ b/internal/protoschema/bigquery/options.go @@ -33,9 +33,18 @@ func WithMaxRecursionDepth(maxRecursionDepth int) GenerateOptions { }) } +// WithGenerateAllMessages returns a GenerateOptions that generates all messages, not just those +// with the extension option. +func WithGenerateAllMessages() GenerateOptions { + return generateOptionsFunc(func(options *generateOptions) { + options.generateAllMessages = true + }) +} + type generateOptions struct { - maxDepth int - maxRecursionDepth int + maxDepth int + maxRecursionDepth int + generateAllMessages bool } type generateOptionsFunc func(*generateOptions) diff --git a/internal/protoschema/plugin/pluginbigquery/pluginbigquery.go b/internal/protoschema/plugin/pluginbigquery/pluginbigquery.go index 2a9384d..864d55c 100644 --- a/internal/protoschema/plugin/pluginbigquery/pluginbigquery.go +++ b/internal/protoschema/plugin/pluginbigquery/pluginbigquery.go @@ -16,10 +16,14 @@ package pluginbigquery import ( "context" - "fmt" + "path/filepath" + "strings" + bqproto "github.com/GoogleCloudPlatform/protoc-gen-bq-schema/protos" "github.com/bufbuild/protoplugin" "github.com/bufbuild/protoschema-plugins/internal/protoschema/bigquery" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" ) // Handle implements protoplugin.Handler and is the main entry point for the plugin. @@ -36,6 +40,11 @@ func Handle( for _, fileDescriptor := range fileDescriptors { for i := range fileDescriptor.Messages().Len() { messageDescriptor := fileDescriptor.Messages().Get(i) + + tableName := tryGetTableNameFromOptions(messageDescriptor) + if tableName == "" { + tableName = string(messageDescriptor.Name()) + } schema, _, err := bigquery.Generate(messageDescriptor) if err != nil { return err @@ -47,8 +56,10 @@ func Handle( if len(data) == 0 || string(data) == "null" { continue } + name := tableName + "." + bigquery.FileExtension + filename := strings.ReplaceAll(string(fileDescriptor.Package()), ".", "/") responseWriter.AddFile( - fmt.Sprintf("%s.%s", messageDescriptor.FullName(), bigquery.FileExtension), + filepath.Join(filename, name), string(data), ) } @@ -57,3 +68,17 @@ func Handle( responseWriter.SetFeatureProto3Optional() return nil } + +func tryGetTableNameFromOptions(messageDescriptor protoreflect.MessageDescriptor) string { + if !proto.HasExtension(messageDescriptor.Options(), bqproto.E_BigqueryOpts) { + return "" + } + messageOptions, ok := proto.GetExtension( + messageDescriptor.Options(), + bqproto.E_BigqueryOpts, + ).(*bqproto.BigQueryMessageOptions) + if !ok || messageOptions == nil { + return "" + } + return messageOptions.GetTableName() +}