diff --git a/private/buf/bufctl/controller.go b/private/buf/bufctl/controller.go index c57ba10666..617b2e9c31 100644 --- a/private/buf/bufctl/controller.go +++ b/private/buf/bufctl/controller.go @@ -1332,7 +1332,7 @@ func filterImage( newImage = bufimage.ImageWithoutImports(newImage) } if len(functionOptions.imageTypes) > 0 { - newImage, err = bufimageutil.ImageFilteredByTypes(newImage, functionOptions.imageTypes...) + newImage, err = bufimageutil.FilterImage(newImage, bufimageutil.WithIncludeTypes(functionOptions.imageTypes...)) if err != nil { return nil, err } diff --git a/private/buf/bufgen/generator.go b/private/buf/bufgen/generator.go index 7a2aa5dd92..878a55d083 100644 --- a/private/buf/bufgen/generator.go +++ b/private/buf/bufgen/generator.go @@ -320,7 +320,7 @@ func (g *generator) execLocalPlugin( } if len(excludeOptions) > 0 { for i, pluginImage := range pluginImages { - pluginImage, err := bufimageutil.ExcludeOptions(pluginImage, excludeOptions...) + pluginImage, err := bufimageutil.FilterImage(pluginImage, bufimageutil.WithExcludeOptions(excludeOptions...)) if err != nil { return nil, err } diff --git a/private/buf/cmd/buf/command/convert/convert.go b/private/buf/cmd/buf/command/convert/convert.go index f246f60862..0bd5bf5d94 100644 --- a/private/buf/cmd/buf/command/convert/convert.go +++ b/private/buf/cmd/buf/command/convert/convert.go @@ -185,7 +185,7 @@ func run( resolveWellKnownType = true } if schemaImage != nil { - _, filterErr := bufimageutil.ImageFilteredByTypes(schemaImage, flags.Type) + _, filterErr := bufimageutil.FilterImage(schemaImage, bufimageutil.WithIncludeTypes(flags.Type)) if errors.Is(filterErr, bufimageutil.ErrImageFilterTypeNotFound) { resolveWellKnownType = true } @@ -283,5 +283,5 @@ func wellKnownTypeImage( if err != nil { return nil, err } - return bufimageutil.ImageFilteredByTypes(image, wellKnownTypeName) + return bufimageutil.FilterImage(image, bufimageutil.WithIncludeTypes(wellKnownTypeName)) } diff --git a/private/bufpkg/bufimage/bufimageutil/bufimageutil.go b/private/bufpkg/bufimage/bufimageutil/bufimageutil.go index 6c88624e43..1174968bc8 100644 --- a/private/bufpkg/bufimage/bufimageutil/bufimageutil.go +++ b/private/bufpkg/bufimage/bufimageutil/bufimageutil.go @@ -144,14 +144,22 @@ func WithExcludeOptions(typeNames ...string) ImageFilterOption { } } -// ImageFilteredByTypes returns a minimal image containing only the descriptors -// required to define those types. The resulting contains only files in which -// those descriptors and their transitive closure of required descriptors, with -// each file only contains the minimal required types and imports. +// FilterImage returns a minimal image containing only the descriptors +// required to define the set of types provided by the filter options. If no +// filter options are provided, the original image is returned. // -// Although this returns a new [bufimage.Image], it mutates the original image's -// underlying file's [descriptorpb.FileDescriptorProto]. So the old image should -// not continue to be used. +// The filtered image will contain only the files that contain the definitions of +// the specified types, and their transitive dependencies. If a file is no longer +// required, it will be removed from the image. Only the minimal set of types +// required to define the specified types will be included in the filtered image. +// +// Excluded types and options are not included in the filtered image. If an +// included type transitively depens on the excluded type, the descriptor will +// be altered to remove the dependency. +// +// This returns a new [bufimage.Image] that is a shallow copy of the underlying +// [descriptorpb.FileDescriptorProto]s of the original. The new image may therefore +// share state with the original image, so it should not be modified. // // A descriptor is said to require another descriptor if the dependent // descriptor is needed to accurately and completely describe that descriptor. @@ -212,10 +220,6 @@ func WithExcludeOptions(typeNames ...string) ImageFilterOption { // files: [foo.proto, bar.proto] // messages: [pkg.Baz, other.Quux, other.Qux] // extensions: [other.my_option] -func ImageFilteredByTypes(image bufimage.Image, types ...string) (bufimage.Image, error) { - return ImageFilteredByTypesWithOptions(image, types) -} - func FilterImage(image bufimage.Image, options ...ImageFilterOption) (bufimage.Image, error) { if len(options) == 0 { return image, nil @@ -227,13 +231,6 @@ func FilterImage(image bufimage.Image, options ...ImageFilterOption) (bufimage.I return filterImage(image, filterOptions) } -// ImageFilteredByTypesWithOptions returns a minimal image containing only the descriptors -// required to define those types. See ImageFilteredByTypes for more details. This version -// allows for customizing the behavior with options. -func ImageFilteredByTypesWithOptions(image bufimage.Image, types []string, opts ...ImageFilterOption) (bufimage.Image, error) { - return FilterImage(image, append(opts, WithIncludeTypes(types...))...) -} - // StripSourceRetentionOptions strips any options with a retention of "source" from // the descriptors in the given image. The image is not mutated but instead a new // image is returned. The returned image may share state with the original. diff --git a/private/bufpkg/bufimage/bufimageutil/bufimageutil_test.go b/private/bufpkg/bufimage/bufimageutil/bufimageutil_test.go index fac16862b8..84080ba74e 100644 --- a/private/bufpkg/bufimage/bufimageutil/bufimageutil_test.go +++ b/private/bufpkg/bufimage/bufimageutil/bufimageutil_test.go @@ -180,7 +180,7 @@ func TestTransitivePublic(t *testing.T) { ) require.NoError(t, err) - filteredImage, err := ImageFilteredByTypes(image, "c.Baz") + filteredImage, err := FilterImage(image, WithIncludeTypes("c.Baz")) require.NoError(t, err) _, err = protodesc.NewFiles(bufimage.ImageToFileDescriptorSet(filteredImage)) @@ -220,15 +220,15 @@ func TestTypesFromMainModule(t *testing.T) { bProtoFileInfo, err := dep.StatFileInfo(ctx, "b.proto") require.NoError(t, err) require.False(t, bProtoFileInfo.IsTargetFile()) - _, err = ImageFilteredByTypes(image, "dependency.Dep") + _, err = FilterImage(image, WithIncludeTypes("dependency.Dep")) require.Error(t, err) assert.ErrorIs(t, err, ErrImageFilterTypeIsImport) // allowed if we specify option - _, err = ImageFilteredByTypesWithOptions(image, []string{"dependency.Dep"}, WithAllowFilterByImportedType()) + _, err = FilterImage(image, WithIncludeTypes("dependency.Dep"), WithAllowFilterByImportedType()) require.NoError(t, err) - _, err = ImageFilteredByTypes(image, "nonexisting") + _, err = FilterImage(image, WithIncludeTypes("nonexisting")) require.Error(t, err) assert.ErrorIs(t, err, ErrImageFilterTypeNotFound) } @@ -259,7 +259,7 @@ func runDiffTest(t *testing.T, testdataDir string, typenames []string, expectedF bucket, image, err := getImage(ctx, slogtestext.NewLogger(t), testdataDir, bufimage.WithExcludeSourceCodeInfo()) require.NoError(t, err) - filteredImage, err := ImageFilteredByTypesWithOptions(image, typenames, opts...) + filteredImage, err := FilterImage(image, append(opts, WithIncludeTypes(typenames...))...) require.NoError(t, err) assert.NotNil(t, image) assert.True(t, imageIsDependencyOrdered(filteredImage), "image files not in dependency order") @@ -323,7 +323,7 @@ func runSourceCodeInfoTest(t *testing.T, typename string, expectedFile string, o bucket, image, err := getImage(ctx, slogtestext.NewLogger(t), "testdata/sourcecodeinfo") require.NoError(t, err) - filteredImage, err := ImageFilteredByTypesWithOptions(image, []string{typename}, opts...) + filteredImage, err := FilterImage(image, append(opts, WithIncludeTypes(typename))...) require.NoError(t, err) imageFile := filteredImage.GetFile("test.proto") @@ -477,7 +477,7 @@ func benchmarkFilterImage(b *testing.B, opts ...bufimage.BuildImageOption) { require.NoError(b, err) b.StartTimer() - _, err = ImageFilteredByTypes(image, typeName) + _, err = FilterImage(image, WithIncludeTypes(typeName)) require.NoError(b, err) i++ if i == b.N { diff --git a/private/bufpkg/bufimage/bufimageutil/image_filter.go b/private/bufpkg/bufimage/bufimageutil/image_filter.go index a5c3895f04..fe4432ba86 100644 --- a/private/bufpkg/bufimage/bufimageutil/image_filter.go +++ b/private/bufpkg/bufimage/bufimageutil/image_filter.go @@ -37,7 +37,6 @@ func filterImage(image bufimage.Image, options *imageFilterOptions) (bufimage.Im if err != nil { return nil, err } - // Loop over image files in revserse DAG order. Imports that are no longer // imported by a previous file are dropped from the image. imageFiles := image.Files() @@ -57,8 +56,6 @@ func filterImage(image bufimage.Image, options *imageFilterOptions) (bufimage.Im imageFile, imageIndex, filter, - //typeFilter, - //optionsFilter, ) if err != nil { return nil, err @@ -84,8 +81,6 @@ func filterImageFile( imageFile bufimage.ImageFile, imageIndex *imageIndex, filter *fullNameFilter, - //typesFilter fullNameFilter, - //optionsFilter fullNameFilter, ) (bufimage.ImageFile, error) { fileDescriptor := imageFile.FileDescriptorProto() var sourcePathsRemap sourcePathsRemapTrie @@ -132,15 +127,11 @@ func addRemapsForFileDescriptor( fileDescriptor *descriptorpb.FileDescriptorProto, imageIndex *imageIndex, filter *fullNameFilter, - //typesFilter fullNameFilter, - //optionsFilter fullNameFilter, ) (bool, error) { packageName := protoreflect.FullName(fileDescriptor.GetPackage()) if packageName != "" { // Check if filtered by the package name. - isIncluded, isExplicit := filter.hasType(packageName) - if !isIncluded && isExplicit { - // The package is excluded. + if !filter.hasType(packageName) { return false, nil } } @@ -155,18 +146,27 @@ func addRemapsForFileDescriptor( sourcePath := make(protoreflect.SourcePath, 0, 8) // Walk the file descriptor. - if _, err := addRemapsForSlice(sourcePathsRemap, packageName, append(sourcePath, fileMessagesTag), fileDescriptor.MessageType, builder.addRemapsForDescriptor); err != nil { + isIncluded := false + hasMessages, err := addRemapsForSlice(sourcePathsRemap, packageName, append(sourcePath, fileMessagesTag), fileDescriptor.MessageType, builder.addRemapsForDescriptor) + if err != nil { return false, err } - if _, err := addRemapsForSlice(sourcePathsRemap, packageName, append(sourcePath, fileEnumsTag), fileDescriptor.EnumType, builder.addRemapsForEnum); err != nil { + isIncluded = isIncluded || hasMessages + hasEnums, err := addRemapsForSlice(sourcePathsRemap, packageName, append(sourcePath, fileEnumsTag), fileDescriptor.EnumType, builder.addRemapsForEnum) + if err != nil { return false, err } - if _, err := addRemapsForSlice(sourcePathsRemap, packageName, append(sourcePath, fileServicesTag), fileDescriptor.Service, builder.addRemapsForService); err != nil { + isIncluded = isIncluded || hasEnums + hasServices, err := addRemapsForSlice(sourcePathsRemap, packageName, append(sourcePath, fileServicesTag), fileDescriptor.Service, builder.addRemapsForService) + if err != nil { return false, err } - if _, err := addRemapsForSlice(sourcePathsRemap, packageName, append(sourcePath, fileExtensionsTag), fileDescriptor.Extension, builder.addRemapsForField); err != nil { + isIncluded = isIncluded || hasServices + hasExtensions, err := addRemapsForSlice(sourcePathsRemap, packageName, append(sourcePath, fileExtensionsTag), fileDescriptor.Extension, builder.addRemapsForField) + if err != nil { return false, err } + isIncluded = isIncluded || hasExtensions if err := builder.addRemapsForOptions(sourcePathsRemap, append(sourcePath, fileOptionsTag), fileDescriptor.Options); err != nil { return false, err } @@ -211,7 +211,7 @@ func addRemapsForFileDescriptor( } } } - return true, nil + return isIncluded, nil } func (b *sourcePathsBuilder) addRemapsForDescriptor( @@ -221,64 +221,44 @@ func (b *sourcePathsBuilder) addRemapsForDescriptor( descriptor *descriptorpb.DescriptorProto, ) (bool, error) { fullName := getFullName(parentName, descriptor) - isIncluded, isExplicit := b.filter.hasType(fullName) - if !isIncluded && isExplicit { + mode := b.filter.inclusionMode(fullName) + if mode == inclusionModeNone { // The type is excluded. return false, nil } - // - // If the message is only enclosin included message remove the fields. - if isIncluded { - if _, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageFieldsTag), descriptor.GetField(), b.addRemapsForField); err != nil { - return false, err - } - if _, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageExtensionsTag), descriptor.GetExtension(), b.addRemapsForField); err != nil { - return false, err - } - for index, extensionRange := range descriptor.GetExtensionRange() { - extensionRangeOptionsPath := append(sourcePath, messageExtensionRangesTag, int32(index), extensionRangeOptionsTag) - if err := b.addRemapsForOptions(sourcePathsRemap, extensionRangeOptionsPath, extensionRange.GetOptions()); err != nil { - return false, err - } - } - } else { - sourcePathsRemap.markDeleted(append(sourcePath, messageFieldsTag)) - sourcePathsRemap.markDeleted(append(sourcePath, messageOneofsTag)) - // TODO: check if extensions are removed??? - sourcePathsRemap.markDeleted(append(sourcePath, messageExtensionRangesTag)) - sourcePathsRemap.markDeleted(append(sourcePath, messageExtensionRangesTag)) + if mode == inclusionModeEnclosing { + // TODO: check if other descriptor fields are removed? sourcePathsRemap.markDeleted(append(sourcePath, messageReservedRangesTag)) sourcePathsRemap.markDeleted(append(sourcePath, messageReservedNamesTag)) - //for index := range descriptor.GetExtensionRange() { - // sourcePathsRemap.markDeleted(append(sourcePath, messageExtensionRangesTag, int32(index), extensionRangeOptionsTag)) - //} - //sourcePathsRemap.markDeleted(append(sourcePath, messageOptionsTag)) } - // Walk the nested types. - hasNestedTypes, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageNestedMessagesTag), descriptor.NestedType, b.addRemapsForDescriptor) - if err != nil { + + // If the message is only enclosing, we search all fields for extensions. + if _, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageFieldsTag), descriptor.GetField(), b.addRemapsForField); err != nil { return false, err } - isIncluded = isIncluded || hasNestedTypes - - // Walk the enum types. - hasEnums, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageEnumsTag), descriptor.EnumType, b.addRemapsForEnum) - if err != nil { + if _, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageExtensionsTag), descriptor.GetExtension(), b.addRemapsForField); err != nil { return false, err } - isIncluded = isIncluded || hasEnums + for index, extensionRange := range descriptor.GetExtensionRange() { + extensionRangeOptionsPath := append(sourcePath, messageExtensionRangesTag, int32(index), extensionRangeOptionsTag) + if err := b.addRemapsForOptions(sourcePathsRemap, extensionRangeOptionsPath, extensionRange.GetOptions()); err != nil { + return false, err + } + } - // Walk the oneof types. - hasOneofs, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageOneofsTag), descriptor.OneofDecl, b.addRemapsForOneof) - if err != nil { + if _, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageNestedMessagesTag), descriptor.NestedType, b.addRemapsForDescriptor); err != nil { + return false, err + } + if _, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageEnumsTag), descriptor.EnumType, b.addRemapsForEnum); err != nil { + return false, err + } + if _, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, messageOneofsTag), descriptor.OneofDecl, b.addRemapsForOneof); err != nil { return false, err } - isIncluded = isIncluded || hasOneofs - if err := b.addRemapsForOptions(sourcePathsRemap, append(sourcePath, messageOptionsTag), descriptor.GetOptions()); err != nil { return false, err } - return isIncluded, nil + return true, nil } func (b *sourcePathsBuilder) addRemapsForEnum( @@ -289,7 +269,7 @@ func (b *sourcePathsBuilder) addRemapsForEnum( ) (bool, error) { //fullName := b.imageIndex.ByDescriptor[enum] fullName := getFullName(parentName, enum) - if isIncluded, _ := b.filter.hasType(fullName); !isIncluded { + if !b.filter.hasType(fullName) { // The type is excluded, enum values cannot be excluded individually. return false, nil } @@ -316,7 +296,7 @@ func (b *sourcePathsBuilder) addRemapsForOneof( oneof *descriptorpb.OneofDescriptorProto, ) (bool, error) { fullName := getFullName(parentName, oneof) - if isIncluded, _ := b.filter.hasType(fullName); !isIncluded { + if !b.filter.hasType(fullName) { // The type is excluded, enum values cannot be excluded individually. return false, nil } @@ -333,22 +313,18 @@ func (b *sourcePathsBuilder) addRemapsForService( service *descriptorpb.ServiceDescriptorProto, ) (bool, error) { fullName := getFullName(parentName, service) - isIncluded, isExplicit := b.filter.hasType(fullName) - if !isIncluded && isExplicit { + if !b.filter.hasType(fullName) { // The type is excluded. return false, nil } - if isIncluded { - if err := b.addRemapsForOptions(sourcePathsRemap, append(sourcePath, serviceOptionsTag), service.GetOptions()); err != nil { - return false, err - } - } // Walk the service methods. - hasMethods, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, serviceMethodsTag), service.Method, b.addRemapsForMethod) - if err != nil { + if _, err := addRemapsForSlice(sourcePathsRemap, fullName, append(sourcePath, serviceMethodsTag), service.Method, b.addRemapsForMethod); err != nil { return false, err } - return isIncluded || hasMethods, nil + if err := b.addRemapsForOptions(sourcePathsRemap, append(sourcePath, serviceOptionsTag), service.GetOptions()); err != nil { + return false, err + } + return true, nil } func (b *sourcePathsBuilder) addRemapsForMethod( @@ -358,18 +334,18 @@ func (b *sourcePathsBuilder) addRemapsForMethod( method *descriptorpb.MethodDescriptorProto, ) (bool, error) { fullName := getFullName(parentName, method) - if isIncluded, _ := b.filter.hasType(fullName); !isIncluded { + if !b.filter.hasType(fullName) { // The type is excluded. return false, nil } inputName := protoreflect.FullName(strings.TrimPrefix(method.GetInputType(), ".")) - if isIncluded, _ := b.filter.hasType(inputName); !isIncluded { + if !b.filter.hasType(inputName) { // The input type is excluded. return false, fmt.Errorf("input type %s of method %s is excluded", inputName, fullName) } b.addRequiredType(inputName) outputName := protoreflect.FullName(strings.TrimPrefix(method.GetOutputType(), ".")) - if isIncluded, _ := b.filter.hasType(outputName); !isIncluded { + if !b.filter.hasType(outputName) { // The output type is excluded. return false, fmt.Errorf("output type %s of method %s is excluded", outputName, fullName) } @@ -389,17 +365,19 @@ func (b *sourcePathsBuilder) addRemapsForField( if field.Extendee != nil { // This is an extension field. extendeeName := protoreflect.FullName(strings.TrimPrefix(field.GetExtendee(), ".")) - if isIncluded, _ := b.filter.hasType(extendeeName); !isIncluded { + if !b.filter.hasType(extendeeName) { return false, nil } b.addRequiredType(extendeeName) + } else if b.filter.inclusionMode(parentName) == inclusionModeEnclosing { + return false, nil // The field is excluded. } switch field.GetType() { case descriptorpb.FieldDescriptorProto_TYPE_ENUM, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, descriptorpb.FieldDescriptorProto_TYPE_GROUP: typeName := protoreflect.FullName(strings.TrimPrefix(field.GetTypeName(), ".")) - if isIncluded, _ := b.filter.hasType(typeName); !isIncluded { + if !b.filter.hasType(typeName) { return false, nil } b.addRequiredType(typeName) @@ -438,9 +416,7 @@ func (b *sourcePathsBuilder) addRemapsForOptions( options := optionsMessage.ProtoReflect() numFieldsToKeep := 0 options.Range(func(fd protoreflect.FieldDescriptor, val protoreflect.Value) bool { - - isIncluded, _ := b.filter.hasOption(fd.FullName(), fd.IsExtension()) - if !isIncluded { + if !b.filter.hasOption(fd.FullName(), fd.IsExtension()) { // Remove this option. optionPath := append(optionsPath, int32(fd.Number())) sourcePathsRemap.markDeleted(optionPath) @@ -627,11 +603,6 @@ func remapListReflect( if fromIndex != int(remapNode.oldIndex) || toIndex != int(remapNode.newIndex) { return fmt.Errorf("unexpected list move %d to %d, expected %d to %d", remapNode.oldIndex, remapNode.newIndex, fromIndex, toIndex) } - //if toIndex != int(remapNode.newIndex) { - // // Mutate the remap node to reflect the actual index. - // // TODO: this is a hack. - // remapNode.newIndex = int32(toIndex) - //} // If no children, the value is unchanged. if len(remapNode.children) > 0 { // Must be a list of messages to have children. diff --git a/private/bufpkg/bufimage/bufimageutil/image_index.go b/private/bufpkg/bufimage/bufimageutil/image_index.go index 575f8b991f..240f8be92a 100644 --- a/private/bufpkg/bufimage/bufimageutil/image_index.go +++ b/private/bufpkg/bufimage/bufimageutil/image_index.go @@ -37,15 +37,12 @@ type imageIndex struct { ByName map[protoreflect.FullName]elementInfo // Files maps file names to the file descriptor protos. Files map[string]*descriptorpb.FileDescriptorProto - // NameToExtensions maps fully qualified type names to all known // extension definitions for a type name. NameToExtensions map[string][]*descriptorpb.FieldDescriptorProto - // NameToOptions maps `google.protobuf.*Options` type names to their // known extensions by field tag. NameToOptions map[string]map[int32]*descriptorpb.FieldDescriptorProto - // Packages maps package names to package contents. Packages map[string]*packageInfo } @@ -93,12 +90,28 @@ func newImageIndexForImage(image bufimage.Image, options *imageFilterOptions) (* index.NameToExtensions = make(map[string][]*descriptorpb.FieldDescriptorProto) } + addExtension := func(ext *descriptorpb.FieldDescriptorProto) { + extendeeName := strings.TrimPrefix(ext.GetExtendee(), ".") + if options.includeCustomOptions && isOptionsTypeName(extendeeName) { + if _, ok := index.NameToOptions[extendeeName]; !ok { + index.NameToOptions[extendeeName] = make(map[int32]*descriptorpb.FieldDescriptorProto) + } + index.NameToOptions[extendeeName][ext.GetNumber()] = ext + } + if options.includeKnownExtensions { + index.NameToExtensions[extendeeName] = append(index.NameToExtensions[extendeeName], ext) + } + } + for _, imageFile := range image.Files() { pkg := addPackageToIndex(imageFile.FileDescriptorProto().GetPackage(), index) pkg.files = append(pkg.files, imageFile) fileName := imageFile.Path() fileDescriptorProto := imageFile.FileDescriptorProto() index.Files[fileName] = fileDescriptorProto + for _, fd := range fileDescriptorProto.GetExtension() { + addExtension(fd) + } err := walk.DescriptorProtos(fileDescriptorProto, func(name protoreflect.FullName, msg proto.Message) error { if _, existing := index.ByName[name]; existing { return fmt.Errorf("duplicate for %q", name) @@ -137,21 +150,8 @@ func newImageIndexForImage(image bufimage.Image, options *imageFilterOptions) (* pkg.types = append(pkg.types, name) } - ext, ok := descriptor.(*descriptorpb.FieldDescriptorProto) - if !ok || ext.Extendee == nil { - // not an extension, so the rest does not apply - return nil - } - - extendeeName := strings.TrimPrefix(ext.GetExtendee(), ".") - if options.includeCustomOptions && isOptionsTypeName(extendeeName) { - if _, ok := index.NameToOptions[extendeeName]; !ok { - index.NameToOptions[extendeeName] = make(map[int32]*descriptorpb.FieldDescriptorProto) - } - index.NameToOptions[extendeeName][ext.GetNumber()] = ext - } - if options.includeKnownExtensions { - index.NameToExtensions[extendeeName] = append(index.NameToExtensions[extendeeName], ext) + if ext, ok := descriptor.(*descriptorpb.FieldDescriptorProto); ok && ext.GetExtendee() != "" { + addExtension(ext) } return nil }) diff --git a/private/bufpkg/bufimage/bufimageutil/image_types.go b/private/bufpkg/bufimage/bufimageutil/image_types.go index 80296ccd79..85e8f35c30 100644 --- a/private/bufpkg/bufimage/bufimageutil/image_types.go +++ b/private/bufpkg/bufimage/bufimageutil/image_types.go @@ -23,11 +23,20 @@ import ( "google.golang.org/protobuf/types/descriptorpb" ) +type inclusionMode int + +const ( + inclusionModeNone inclusionMode = iota + inclusionModeEnclosing + inclusionModeExplicit +) + type fullNameFilter struct { options *imageFilterOptions index *imageIndex - includes map[protoreflect.FullName]struct{} + includes map[protoreflect.FullName]inclusionMode excludes map[protoreflect.FullName]struct{} + depth int } func newFullNameFilter( @@ -53,38 +62,42 @@ func newFullNameFilter( return nil, err } } + if err := filter.includeExtensions(); err != nil { + return nil, err + } return filter, nil } -func (f *fullNameFilter) hasType(fullName protoreflect.FullName) (isIncluded bool, isExplicit bool) { - if len(f.options.excludeTypes) > 0 || f.excludes != nil { +func (f *fullNameFilter) inclusionMode(fullName protoreflect.FullName) inclusionMode { + if f.excludes != nil { if _, ok := f.excludes[fullName]; ok { - return false, true + return inclusionModeNone } } - if len(f.options.includeTypes) > 0 || f.includes != nil { - if _, ok := f.includes[fullName]; ok { - return true, true - } - return false, false + if f.includes != nil { + return f.includes[fullName] } - return true, false + return inclusionModeExplicit +} + +func (f *fullNameFilter) hasType(fullName protoreflect.FullName) (isIncluded bool) { + return f.inclusionMode(fullName) != inclusionModeNone } -func (f *fullNameFilter) hasOption(fullName protoreflect.FullName, isExtension bool) (isIncluded bool, isExplicit bool) { +func (f *fullNameFilter) hasOption(fullName protoreflect.FullName, isExtension bool) (isIncluded bool) { if f.options.excludeOptions != nil { if _, ok := f.options.excludeOptions[string(fullName)]; ok { - return false, true + return false } } if !f.options.includeCustomOptions { - return !isExtension, true + return !isExtension } if f.options.includeOptions != nil { _, ok := f.options.includeOptions[string(fullName)] - return ok, true + return ok } - return true, false + return true } func (f *fullNameFilter) isExplicitExclude(fullName protoreflect.FullName) bool { @@ -142,7 +155,6 @@ func (f *fullNameFilter) excludeElement(fullName protoreflect.FullName, descript } return nil case *descriptorpb.DescriptorProto: - f.excludes[fullName] = struct{}{} // Exclude all sub-elements if err := forEachDescriptor(fullName, descriptor.GetNestedType(), f.excludeElement); err != nil { return err @@ -176,11 +188,31 @@ func (f *fullNameFilter) include(fullName protoreflect.FullName) error { return nil } if descriptorInfo, ok := f.index.ByName[fullName]; ok { + if err := f.includeElement(fullName, descriptorInfo.element); err != nil { + return err + } // Include the enclosing parent options. - if err := f.includeEnclosingOptions(descriptorInfo.parentName); err != nil { + fileDescriptor := descriptorInfo.imageFile.FileDescriptorProto() + if err := f.includeOptions(fileDescriptor); err != nil { return err } - return f.includeElement(fullName, descriptorInfo.element) + // loop through all enclosing parents since nesting level + // could be arbitrarily deep + for parentName := descriptorInfo.parentName; parentName != ""; { + if isIncluded := f.hasType(parentName); isIncluded { + break + } + f.includes[parentName] = inclusionModeEnclosing + parentInfo, ok := f.index.ByName[parentName] + if !ok { + break + } + if err := f.includeOptions(parentInfo.element); err != nil { + return err + } + parentName = parentInfo.parentName + } + return nil } packageInfo, ok := f.index.Packages[string(fullName)] if !ok { @@ -193,7 +225,6 @@ func (f *fullNameFilter) include(fullName protoreflect.FullName) error { return err } } - f.includes[fullName] = struct{}{} for _, subPackage := range packageInfo.subPackages { if err := f.include(subPackage.fullName); err != nil { return err @@ -210,9 +241,9 @@ func (f *fullNameFilter) includeElement(fullName protoreflect.FullName, descript return nil // already excluded } if f.includes == nil { - f.includes = make(map[protoreflect.FullName]struct{}) + f.includes = make(map[protoreflect.FullName]inclusionMode) } - f.includes[fullName] = struct{}{} + f.includes[fullName] = inclusionModeExplicit if err := f.includeOptions(descriptor); err != nil { return err @@ -243,18 +274,8 @@ func (f *fullNameFilter) includeElement(fullName protoreflect.FullName, descript if err := forEachDescriptor(fullName, descriptor.GetField(), f.includeElement); err != nil { return err } - if err := forEachDescriptor(fullName, descriptor.GetExtension(), f.includeElement); err != nil { - return err - } - - // TODO: Include known extensions. - //if f.options.includeKnownExtensions { - // for _, extensionDescriptor := range f.index.NameToExtensions[string(fullName)] { - // if err := f.includeElement(fullName, extensionDescriptor); err != nil { - // return err - // } - // } - //} + // Extensions are handled after all elements are included. + // This allows us to ensure that the extendee is included first. return nil case *descriptorpb.FieldDescriptorProto: if descriptor.Extendee != nil { @@ -329,19 +350,34 @@ func (f *fullNameFilter) includeElement(fullName protoreflect.FullName, descript } } -func (f *fullNameFilter) includeEnclosingOptions(fullName protoreflect.FullName) error { - // loop through all enclosing parents since nesting level - // could be arbitrarily deep - for info, ok := f.index.ByName[fullName]; ok; { - if err := f.includeOptions(info.element); err != nil { - return err +func (f *fullNameFilter) includeExtensions() error { + if !f.options.includeKnownExtensions { + return nil // nothing to do + } + // TODO + /*if f.options.includeKnownExtensions && len(f.options.includeTypes) > 0 { + for extendee, extensions := range f.index.NameToExtensions { + extendee := protoreflect.FullName(extendee) + if f.inclusionMode(extendee)== inclusionModeNone { + continue + } + for _, extension := range extensions { + typeName := protoreflect.FullName(strings.TrimPrefix(extension.GetTypeName(), ".")) + if typeName == "" && f.hasType(typeName) { + if err := f.includeElement(extendee, extension); err != nil { + return err + } + } + } } - info, ok = f.index.ByName[info.parentName] - } + }*/ return nil } func (f *fullNameFilter) includeOptions(descriptor proto.Message) (err error) { + if !f.options.includeCustomOptions { + return nil + } var optionsMessage proto.Message switch descriptor := descriptor.(type) { case *descriptorpb.FileDescriptorProto: @@ -372,7 +408,7 @@ func (f *fullNameFilter) includeOptions(descriptor proto.Message) (err error) { optionsName := options.Descriptor().FullName() optionsByNumber := f.index.NameToOptions[string(optionsName)] options.Range(func(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) bool { - if isIncluded, _ := f.hasOption(fieldDescriptor.FullName(), fieldDescriptor.IsExtension()); !isIncluded { + if !f.hasOption(fieldDescriptor.FullName(), fieldDescriptor.IsExtension()) { return true } if err = f.includeOptionValue(fieldDescriptor, value); err != nil { @@ -381,6 +417,7 @@ func (f *fullNameFilter) includeOptions(descriptor proto.Message) (err error) { if !fieldDescriptor.IsExtension() { return true } + extensionField, ok := optionsByNumber[int32(fieldDescriptor.Number())] if !ok { err = fmt.Errorf("cannot find ext no %d on %s", fieldDescriptor.Number(), optionsName)