Skip to content

Commit

Permalink
Refactor of the proto description logic for safety (#326)
Browse files Browse the repository at this point in the history
Refactor of the proto description logic for safety
  • Loading branch information
TristonianJones authored Mar 26, 2020
1 parent 74ccfea commit 525447a
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 80 deletions.
29 changes: 19 additions & 10 deletions common/types/pb/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,28 @@ import (
descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
)

// EnumDescription maps a qualified enum name to its numeric value.
type EnumDescription struct {
enumName string
file *FileDescription
desc *descpb.EnumValueDescriptorProto
// NewEnumValueDescription produces an enum value description with the fully qualified enum value
// name and the enum value descriptor.
func NewEnumValueDescription(name string,
desc *descpb.EnumValueDescriptorProto) *EnumValueDescription {
return &EnumValueDescription{
enumValueName: name,
desc: desc,
}
}

// Name of the enum.
func (ed *EnumDescription) Name() string {
return ed.enumName
// EnumValueDescription maps a fully-qualified enum value name to its numeric value.
type EnumValueDescription struct {
enumValueName string
desc *descpb.EnumValueDescriptorProto
}

// Value (numeric) of the enum.
func (ed *EnumDescription) Value() int32 {
// Name returns the fully-qualified identifier name for the enum value.
func (ed *EnumValueDescription) Name() string {
return ed.enumValueName
}

// Value returns the (numeric) value of the enum.
func (ed *EnumValueDescription) Value() int32 {
return ed.desc.GetNumber()
}
108 changes: 73 additions & 35 deletions common/types/pb/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,38 @@ package pb

import (
"fmt"
"sync/atomic"

descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
)

// FileDescription holds a map of all types and enums declared within a .proto
// file.
// NewFileDescription returns a FileDescription instance with a complete listing of all the message
// types and enum values declared within any scope in the file.
func NewFileDescription(fileDesc *descpb.FileDescriptorProto, pbdb *Db) *FileDescription {
isProto3 := fileDesc.GetSyntax() == "proto3"
metadata := collectFileMetadata(fileDesc)
enums := make(map[string]*EnumValueDescription)
for name, enumVal := range metadata.enumValues {
enums[name] = NewEnumValueDescription(name, enumVal)
}
types := make(map[string]*TypeDescription)
for name, msgType := range metadata.msgTypes {
types[name] = NewTypeDescription(name, msgType, isProto3, pbdb.DescribeType)
}
return &FileDescription{
types: types,
enums: enums,
}
}

// FileDescription holds a map of all types and enum values declared within a proto file.
type FileDescription struct {
pbdb *Db
desc *descpb.FileDescriptorProto
types map[string]*TypeDescription
enums map[string]*EnumDescription
enums map[string]*EnumValueDescription
}

// GetEnumDescription returns an EnumDescription for a qualified enum value
// name declared within the .proto file.
func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumDescription, error) {
func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumValueDescription, error) {
if ed, found := fd.enums[sanitizeProtoName(enumName)]; found {
return ed, nil
}
Expand Down Expand Up @@ -70,44 +85,67 @@ func (fd *FileDescription) GetTypeNames() []string {
return typeNames
}

// Package returns the file's qualified package name.
func (fd *FileDescription) Package() string {
return fd.desc.GetPackage()
// sanitizeProtoName strips the leading '.' from the proto message name.
func sanitizeProtoName(name string) string {
if name != "" && name[0] == '.' {
return name[1:]
}
return name
}

// fileMetadata is a flattened view of message types and enum values within a file descriptor.
type fileMetadata struct {
// msgTypes maps from fully-qualified message name to descriptor.
msgTypes map[string]*descpb.DescriptorProto
// enumValues maps from fully-qualified enum value to enum value descriptor.
enumValues map[string]*descpb.EnumValueDescriptorProto
}

func (fd *FileDescription) indexEnums(pkg string, enumTypes []*descpb.EnumDescriptorProto) {
for _, enumType := range enumTypes {
for _, enumValue := range enumType.Value {
enumValueName := fmt.Sprintf(
"%s.%s.%s", pkg, enumType.GetName(), enumValue.GetName())
fd.enums[enumValueName] = &EnumDescription{
enumName: enumValueName,
file: fd,
desc: enumValue}
fd.pbdb.revFileDescriptorMap[enumValueName] = fd
// collectFileMetadata traverses the proto file object graph to collect message types and enum
// values and index them by their fully qualified names.
func collectFileMetadata(fileDesc *descpb.FileDescriptorProto) *fileMetadata {
pkg := fileDesc.GetPackage()
msgTypes := make(map[string]*descpb.DescriptorProto)
collectMsgTypes(pkg, fileDesc.GetMessageType(), msgTypes)
enumValues := make(map[string]*descpb.EnumValueDescriptorProto)
collectEnumValues(pkg, fileDesc.GetEnumType(), enumValues)
for container, msgType := range msgTypes {
nestedEnums := msgType.GetEnumType()
if len(nestedEnums) == 0 {
continue
}
collectEnumValues(container, nestedEnums, enumValues)
}
return &fileMetadata{
msgTypes: msgTypes,
enumValues: enumValues,
}
}

func (fd *FileDescription) indexTypes(pkg string, msgTypes []*descpb.DescriptorProto) {
// collectMsgTypes recursively collects messages and nested messages into a map of fully
// qualified message names to message descriptors.
func collectMsgTypes(container string,
msgTypes []*descpb.DescriptorProto,
msgTypeMap map[string]*descpb.DescriptorProto) {
for _, msgType := range msgTypes {
msgName := fmt.Sprintf("%s.%s", pkg, msgType.GetName())
td := &TypeDescription{
typeName: msgName,
file: fd,
desc: msgType,
metadata: &atomic.Value{},
msgName := fmt.Sprintf("%s.%s", container, msgType.GetName())
msgTypeMap[msgName] = msgType
nestedTypes := msgType.GetNestedType()
if len(nestedTypes) == 0 {
continue
}
fd.types[msgName] = td
fd.indexTypes(msgName, msgType.NestedType)
fd.indexEnums(msgName, msgType.EnumType)
fd.pbdb.revFileDescriptorMap[msgName] = fd
collectMsgTypes(msgName, nestedTypes, msgTypeMap)
}
}

func sanitizeProtoName(name string) string {
if name != "" && name[0] == '.' {
return name[1:]
// collectEnumValues accumulates the enum values within an enum declaration.
func collectEnumValues(container string,
enumTypes []*descpb.EnumDescriptorProto,
enumValueMap map[string]*descpb.EnumValueDescriptorProto) {
for _, enumType := range enumTypes {
for _, enumValue := range enumType.GetValue() {
name := fmt.Sprintf("%s.%s.%s", container, enumType.GetName(), enumValue.GetName())
enumValueMap[name] = enumValue
}
}
return name
}
3 changes: 0 additions & 3 deletions common/types/pb/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ func TestFileDescription_GetTypes(t *testing.T) {
if td.Name() != typeName {
t.Error("Indexed type name not equal to descriptor type name", td, typeName)
}
if td.file != fd {
t.Error("Indexed type does not refer to current file", td)
}
}
}

Expand Down
36 changes: 18 additions & 18 deletions common/types/pb/pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ func NewDb() *Db {
pbdb := &Db{
revFileDescriptorMap: make(map[string]*FileDescription),
}
// The FileDescription objects in the default db contain lazily initialized TypeDescription
// values which may point to the state contained in the DefaultDb irrespective of this shallow
// copy; however, the type graph for a field is idempotently computed, and is guaranteed to
// only be initialized once thanks to atomic values within the TypeDescription objects, so it
// is safe to share these values across instances.
for k, v := range DefaultDb.revFileDescriptorMap {
pbdb.revFileDescriptorMap[k] = v
}
Expand All @@ -63,14 +68,14 @@ func (pbdb *Db) RegisterDescriptor(fileDesc *descpb.FileDescriptorProto) (*FileD
if found {
return fd, nil
}
fd, err := pbdb.describeFileInternal(fileDesc)
if err != nil {
return nil, err
fd = NewFileDescription(fileDesc, pbdb)
for _, enumValName := range fd.GetEnumNames() {
pbdb.revFileDescriptorMap[enumValName] = fd
}
for _, msgTypeName := range fd.GetTypeNames() {
pbdb.revFileDescriptorMap[msgTypeName] = fd
}
pbdb.revFileDescriptorMap[fileDesc.GetName()] = fd
pkg := fd.Package()
fd.indexTypes(pkg, fileDesc.MessageType)
fd.indexEnums(pkg, fileDesc.EnumType)

// Return the specific file descriptor registered.
return fd, nil
Expand Down Expand Up @@ -98,7 +103,7 @@ func (pbdb *Db) DescribeFile(message proto.Message) (*FileDescription, error) {

// DescribeEnum takes a qualified enum name and returns an `EnumDescription` if it exists in the
// `pb.Db`.
func (pbdb *Db) DescribeEnum(enumName string) (*EnumDescription, error) {
func (pbdb *Db) DescribeEnum(enumName string) (*EnumValueDescription, error) {
enumName = sanitizeProtoName(enumName)
if fd, found := pbdb.revFileDescriptorMap[enumName]; found {
return fd.GetEnumDescription(enumName)
Expand Down Expand Up @@ -131,7 +136,7 @@ func CollectFileDescriptorSet(message proto.Message) (*descpb.FileDescriptorSet,
if _, found := fdMap[dep]; found {
continue
}
depDesc, err := fileDescriptor(dep)
depDesc, err := readFileDescriptor(dep)
if err != nil {
return nil, err
}
Expand All @@ -150,16 +155,11 @@ func CollectFileDescriptorSet(message proto.Message) (*descpb.FileDescriptorSet,
}, nil
}

func (pbdb *Db) describeFileInternal(fileDesc *descpb.FileDescriptorProto) (*FileDescription, error) {
fd := &FileDescription{
pbdb: pbdb,
desc: fileDesc,
types: make(map[string]*TypeDescription),
enums: make(map[string]*EnumDescription)}
return fd, nil
}

func fileDescriptor(protoFileName string) (*descpb.FileDescriptorProto, error) {
// readFileDescriptor will read the gzipped file descriptor for a given proto file and return the
// hydrated FileDescriptorProto.
//
// If the file name is not found or there is an error during deserialization an error is returned.
func readFileDescriptor(protoFileName string) (*descpb.FileDescriptorProto, error) {
gzipped := proto.FileDescriptor(protoFileName)
r, err := gzip.NewReader(bytes.NewReader(gzipped))
if err != nil {
Expand Down
54 changes: 40 additions & 14 deletions common/types/pb/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,50 @@ import (
"fmt"
"reflect"
"strings"
"sync/atomic"
"sync"

"github.com/golang/protobuf/proto"
descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
structpb "github.com/golang/protobuf/ptypes/struct"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

// NewTypeDescription produces a TypeDescription value for the fully-qualified proto type name
// with a given descriptor.
//
// The type description creation method also expects the type to be marked clearly as a proto2 or
// proto3 type, and accepts a typeResolver reference for resolving field TypeDescription during
// lazily initialization of the type which is done atomically.
func NewTypeDescription(typeName string, desc *descpb.DescriptorProto,
isProto3 bool, resolveType typeResolver) *TypeDescription {
return &TypeDescription{
typeName: typeName,
isProto3: isProto3,
desc: desc,
resolveType: resolveType,
}
}

// TypeDescription is a collection of type metadata relevant to expression
// checking and evaluation.
type TypeDescription struct {
typeName string
file *FileDescription
isProto3 bool
desc *descpb.DescriptorProto
metadata *atomic.Value

// resolveType is used to lookup field types during type initialization.
// The resolver may point to shared state; however, this state is guaranteed to be computed at
// most one time.
resolveType typeResolver
init sync.Once
metadata *typeMetadata
}

// typeResolver accepts a type name and returns a TypeDescription.
// The typeResolver is used to resolve field types during lazily initialization of the type
// description metadata.
type typeResolver func(typeName string) (*TypeDescription, error)

type typeMetadata struct {
fields map[string]*FieldDescription // fields by name (proto)
fieldIndices map[int][]*FieldDescription // fields by Go struct idx
Expand Down Expand Up @@ -80,14 +107,14 @@ func (td *TypeDescription) DefaultValue() proto.Message {
return val.(proto.Message)
}

// getMetadata computes the type field metadata used for determining field types and default
// values. The call to makeMetadata within this method is guaranteed to be invoked exactly
// once.
func (td *TypeDescription) getMetadata() *typeMetadata {
meta, ok := td.metadata.Load().(*typeMetadata)
if ok {
return meta
}
meta = td.makeMetadata()
td.metadata.Store(meta)
return meta
td.init.Do(func() {
td.metadata = td.makeMetadata()
})
return td.metadata
}

func (td *TypeDescription) makeMetadata() *typeMetadata {
Expand Down Expand Up @@ -162,7 +189,6 @@ func (td *TypeDescription) newFieldDesc(
index int) *FieldDescription {
getterName := fmt.Sprintf("Get%s", prop.Name)
getter, _ := tdType.MethodByName(getterName)
isProto3 := td.file.desc.GetSyntax() == "proto3"
var field *reflect.StructField
if tdType.Kind() == reflect.Ptr {
tdType = tdType.Elem()
Expand All @@ -177,12 +203,12 @@ func (td *TypeDescription) newFieldDesc(
getter: getter.Func,
field: field,
prop: prop,
isProto3: isProto3,
isProto3: td.isProto3,
isWrapper: isWrapperType(desc),
}
if desc.GetType() == descpb.FieldDescriptorProto_TYPE_MESSAGE {
typeName := sanitizeProtoName(desc.GetTypeName())
fieldType, _ := td.file.pbdb.DescribeType(typeName)
fieldType, _ := td.resolveType(typeName)
fieldDesc.td = fieldType
return fieldDesc
}
Expand All @@ -203,7 +229,7 @@ func (td *TypeDescription) newMapFieldDesc(desc *descpb.FieldDescriptorProto) *F
return &FieldDescription{
desc: desc,
index: int(desc.GetNumber()),
isProto3: td.file.desc.GetSyntax() == "proto3",
isProto3: td.isProto3,
}
}

Expand Down

0 comments on commit 525447a

Please sign in to comment.