Skip to content

Commit

Permalink
Respect bigquery_opts and match output path to package
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Mar 22, 2024
1 parent da36231 commit 1c13d85
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
31 changes: 19 additions & 12 deletions internal/protoschema/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,11 @@ func Generate(input protoreflect.MessageDescriptor, opts ...GenerateOptions) (bi
"google.protobuf.ListValue",
))
generator := &bigQuerySchemaGenerator{
maxDepth: options.maxDepth,
maxRecursionDepth: options.maxRecursionDepth,
seen: make(map[protoreflect.FullName]int),
normalizer: normalizer,
maxDepth: options.maxDepth,
maxRecursionDepth: options.maxRecursionDepth,
generateAllMessages: options.generateAllMessages,
seen: make(map[protoreflect.FullName]int),
normalizer: normalizer,
}
if generator.maxDepth == 0 || generator.maxDepth > 15 {
generator.maxDepth = 15
Expand All @@ -121,10 +122,11 @@ func Generate(input protoreflect.MessageDescriptor, opts ...GenerateOptions) (bi
}

type bigQuerySchemaGenerator struct {
maxDepth int
maxRecursionDepth int
seen map[protoreflect.FullName]int
normalizer *normalize.Normalizer
maxDepth int
maxRecursionDepth int
generateAllMessages bool
seen map[protoreflect.FullName]int
normalizer *normalize.Normalizer
}

func (p *bigQuerySchemaGenerator) generate(msgDesc protoreflect.MessageDescriptor, depth int) (bigquery.Schema, error) {
Expand All @@ -140,7 +142,9 @@ func (p *bigQuerySchemaGenerator) generateFields(msgDesc protoreflect.MessageDes
if err != nil {
return nil, err
}

if msgOptions == nil && !p.generateAllMessages {
return nil, nil
}
msgPb, err := p.normalizer.FindDescriptorProto(msgDesc)
if err != nil {
return nil, err
Expand Down Expand Up @@ -174,9 +178,12 @@ func (p *bigQuerySchemaGenerator) generateFields(msgDesc protoreflect.MessageDes
return result, nil
}

func (p *bigQuerySchemaGenerator) generateField(msgOptions *bqproto.BigQueryMessageOptions,
fieldDesc protoreflect.FieldDescriptor, fieldPb *descriptorpb.FieldDescriptorProto,
depth int) (*bigquery.FieldSchema, error) {
func (p *bigQuerySchemaGenerator) generateField(
msgOptions *bqproto.BigQueryMessageOptions,
fieldDesc protoreflect.FieldDescriptor,
fieldPb *descriptorpb.FieldDescriptorProto,
depth int,
) (*bigquery.FieldSchema, error) {
if depth >= p.maxDepth {
return nil, nil
}
Expand Down
13 changes: 11 additions & 2 deletions internal/protoschema/bigquery/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,18 @@ func WithMaxRecursionDepth(maxRecursionDepth int) GenerateOptions {
})
}

// WithGenerateAllMessages returns a GenerateOptions that generates all messages, not just those
// with the extension option.
func WithGenerateAllMessages() GenerateOptions {
return generateOptionsFunc(func(options *generateOptions) {
options.generateAllMessages = true
})
}

type generateOptions struct {
maxDepth int
maxRecursionDepth int
maxDepth int
maxRecursionDepth int
generateAllMessages bool
}

type generateOptionsFunc func(*generateOptions)
Expand Down
29 changes: 27 additions & 2 deletions internal/protoschema/plugin/pluginbigquery/pluginbigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ package pluginbigquery

import (
"context"
"fmt"
"path/filepath"
"strings"

bqproto "github.com/GoogleCloudPlatform/protoc-gen-bq-schema/protos"
"github.com/bufbuild/protoplugin"
"github.com/bufbuild/protoschema-plugins/internal/protoschema/bigquery"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

// Handle implements protoplugin.Handler and is the main entry point for the plugin.
Expand All @@ -36,6 +40,11 @@ func Handle(
for _, fileDescriptor := range fileDescriptors {
for i := range fileDescriptor.Messages().Len() {
messageDescriptor := fileDescriptor.Messages().Get(i)

tableName := tryGetTableNameFromOptions(messageDescriptor)
if tableName == "" {
tableName = string(messageDescriptor.Name())
}
schema, _, err := bigquery.Generate(messageDescriptor)
if err != nil {
return err
Expand All @@ -47,8 +56,10 @@ func Handle(
if len(data) == 0 || string(data) == "null" {
continue
}
name := tableName + "." + bigquery.FileExtension
filename := strings.ReplaceAll(string(fileDescriptor.Package()), ".", "/")
responseWriter.AddFile(
fmt.Sprintf("%s.%s", messageDescriptor.FullName(), bigquery.FileExtension),
filepath.Join(filename, name),
string(data),
)
}
Expand All @@ -57,3 +68,17 @@ func Handle(
responseWriter.SetFeatureProto3Optional()
return nil
}

func tryGetTableNameFromOptions(messageDescriptor protoreflect.MessageDescriptor) string {
if !proto.HasExtension(messageDescriptor.Options(), bqproto.E_BigqueryOpts) {
return ""
}
messageOptions, ok := proto.GetExtension(
messageDescriptor.Options(),
bqproto.E_BigqueryOpts,
).(*bqproto.BigQueryMessageOptions)
if !ok || messageOptions == nil {
return ""
}
return messageOptions.GetTableName()
}

0 comments on commit 1c13d85

Please sign in to comment.