From 4792071d3d3600bac58e494dc13cf34605e15db0 Mon Sep 17 00:00:00 2001 From: Ralf Haferkamp Date: Thu, 7 Nov 2024 12:00:38 +0100 Subject: [PATCH 1/2] fix(graph): Use the correct opaqueId when Statting OCM shares File shares need to use the base64 encoded path as the opaqueID, while for folder shares the base64 encode '/' should work. Fixes #10495 --- changelog/unreleased/fix-ocm-file-preview.md | 7 +++++++ services/graph/pkg/service/v0/utils.go | 12 ++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 changelog/unreleased/fix-ocm-file-preview.md diff --git a/changelog/unreleased/fix-ocm-file-preview.md b/changelog/unreleased/fix-ocm-file-preview.md new file mode 100644 index 00000000000..39835ae918e --- /dev/null +++ b/changelog/unreleased/fix-ocm-file-preview.md @@ -0,0 +1,7 @@ +Bugfix: Fixed `sharedWithMe` response for OCM shares + +OCM shares returned in the `sharedWithMe` response did not have the `mimeType` property +populated correctly. + +https://github.com/owncloud/ocis/pull/10501 +https://github.com/owncloud/ocis/issues/10495 diff --git a/services/graph/pkg/service/v0/utils.go b/services/graph/pkg/service/v0/utils.go index 02a15dea105..a5480536aca 100644 --- a/services/graph/pkg/service/v0/utils.go +++ b/services/graph/pkg/service/v0/utils.go @@ -2,6 +2,7 @@ package svc import ( "context" + "encoding/base64" "encoding/json" "io" "net/http" @@ -538,14 +539,21 @@ func cs3ReceivedOCMSharesToDriveItems(ctx context.Context, group.Go(func() error { var err error // redeclare + + // for OCM shares the opaqueID is the '/' for shared directories and '/filename' for + // file shares + resOpaqueID := "/" + if receivedShares[0].GetResourceType() == storageprovider.ResourceType_RESOURCE_TYPE_FILE { + resOpaqueID += receivedShares[0].GetName() + } + shareStat, err := gatewayClient.Stat(ctx, &storageprovider.StatRequest{ Ref: &storageprovider.Reference{ ResourceId: &storageprovider.ResourceId{ // TODO maybe the reference is wrong StorageId: utils.OCMStorageProviderID, SpaceId: receivedShares[0].GetId().GetOpaqueId(), - OpaqueId: "", // in OCM resources the opaque id is the base64 encoded path - //OpaqueId: maybe ? receivedShares[0].GetId().GetOpaqueId(), + OpaqueId: base64.StdEncoding.EncodeToString([]byte(resOpaqueID)), }, }, }) From 336d34821fa9edf0b733117f02a16b5c04803ef7 Mon Sep 17 00:00:00 2001 From: Ralf Haferkamp Date: Thu, 7 Nov 2024 13:11:43 +0100 Subject: [PATCH 2/2] fix(linter): Avoid copying protobuf messages This shoud silence some linter warning about copying lock values. --- .../service/v0/api_driveitem_permissions.go | 8 +- .../v0/api_driveitem_permissions_links.go | 4 +- .../pkg/service/v0/api_drives_drive_item.go | 2 +- services/graph/pkg/service/v0/utils.go | 14 +- services/graph/pkg/service/v0/utils_test.go | 21 +- .../protobuf/internal/msgfmt/format.go | 261 +++++++ .../protobuf/testing/protocmp/reflect.go | 258 +++++++ .../protobuf/testing/protocmp/util.go | 684 ++++++++++++++++++ .../protobuf/testing/protocmp/xform.go | 377 ++++++++++ vendor/modules.txt | 2 + 10 files changed, 1607 insertions(+), 24 deletions(-) create mode 100644 vendor/google.golang.org/protobuf/internal/msgfmt/format.go create mode 100644 vendor/google.golang.org/protobuf/testing/protocmp/reflect.go create mode 100644 vendor/google.golang.org/protobuf/testing/protocmp/util.go create mode 100644 vendor/google.golang.org/protobuf/testing/protocmp/xform.go diff --git a/services/graph/pkg/service/v0/api_driveitem_permissions.go b/services/graph/pkg/service/v0/api_driveitem_permissions.go index a015ee64795..9765be44822 100644 --- a/services/graph/pkg/service/v0/api_driveitem_permissions.go +++ b/services/graph/pkg/service/v0/api_driveitem_permissions.go @@ -634,7 +634,7 @@ func (api DriveItemPermissionsApi) Invite(w http.ResponseWriter, r *http.Request return } - permission, err := api.driveItemPermissionsService.Invite(ctx, &itemID, *driveItemInvite) + permission, err := api.driveItemPermissionsService.Invite(ctx, itemID, *driveItemInvite) if err != nil { errorcode.RenderError(w, r, err) return @@ -698,7 +698,7 @@ func (api DriveItemPermissionsApi) ListPermissions(w http.ResponseWriter, r *htt ctx := r.Context() - permissions, err := api.driveItemPermissionsService.ListPermissions(ctx, &itemID, listFederatedRoles, selectRoles) + permissions, err := api.driveItemPermissionsService.ListPermissions(ctx, itemID, listFederatedRoles, selectRoles) if err != nil { errorcode.RenderError(w, r, err) return @@ -774,7 +774,7 @@ func (api DriveItemPermissionsApi) DeletePermission(w http.ResponseWriter, r *ht } ctx := r.Context() - err = api.driveItemPermissionsService.DeletePermission(ctx, &itemID, permissionID) + err = api.driveItemPermissionsService.DeletePermission(ctx, itemID, permissionID) if err != nil { errorcode.RenderError(w, r, err) return @@ -841,7 +841,7 @@ func (api DriveItemPermissionsApi) UpdatePermission(w http.ResponseWriter, r *ht return } - updatedPermission, err := api.driveItemPermissionsService.UpdatePermission(ctx, &itemID, permissionID, permission) + updatedPermission, err := api.driveItemPermissionsService.UpdatePermission(ctx, itemID, permissionID, permission) if err != nil { errorcode.RenderError(w, r, err) return diff --git a/services/graph/pkg/service/v0/api_driveitem_permissions_links.go b/services/graph/pkg/service/v0/api_driveitem_permissions_links.go index 86675f83ee0..1a66adf3993 100644 --- a/services/graph/pkg/service/v0/api_driveitem_permissions_links.go +++ b/services/graph/pkg/service/v0/api_driveitem_permissions_links.go @@ -166,7 +166,7 @@ func (api DriveItemPermissionsApi) CreateLink(w http.ResponseWriter, r *http.Req return } - perm, err := api.driveItemPermissionsService.CreateLink(r.Context(), &driveItemID, createLink) + perm, err := api.driveItemPermissionsService.CreateLink(r.Context(), driveItemID, createLink) if err != nil { errorcode.RenderError(w, r, err) return @@ -228,7 +228,7 @@ func (api DriveItemPermissionsApi) SetLinkPassword(w http.ResponseWriter, r *htt return } - newPermission, err := api.driveItemPermissionsService.SetPublicLinkPassword(ctx, &itemID, permissionID, password.GetPassword()) + newPermission, err := api.driveItemPermissionsService.SetPublicLinkPassword(ctx, itemID, permissionID, password.GetPassword()) if err != nil { errorcode.RenderError(w, r, err) return diff --git a/services/graph/pkg/service/v0/api_drives_drive_item.go b/services/graph/pkg/service/v0/api_drives_drive_item.go index 2e49586ef05..4c7d209a28e 100644 --- a/services/graph/pkg/service/v0/api_drives_drive_item.go +++ b/services/graph/pkg/service/v0/api_drives_drive_item.go @@ -499,7 +499,7 @@ func (api DrivesDriveItemApi) CreateDriveItem(w http.ResponseWriter, r *http.Req return } - if !IsShareJail(driveID) { + if !IsShareJail(&driveID) { api.logger.Debug().Interface("driveID", driveID).Msg(ErrNotAShareJail.Error()) ErrNotAShareJail.Render(w, r) return diff --git a/services/graph/pkg/service/v0/utils.go b/services/graph/pkg/service/v0/utils.go index a5480536aca..5a0936b8e1c 100644 --- a/services/graph/pkg/service/v0/utils.go +++ b/services/graph/pkg/service/v0/utils.go @@ -45,8 +45,8 @@ func IsSpaceRoot(rid *storageprovider.ResourceId) bool { // GetDriveAndItemIDParam parses the driveID and itemID from the request, // validates the common fields and returns the parsed IDs if ok. -func GetDriveAndItemIDParam(r *http.Request, logger *log.Logger) (storageprovider.ResourceId, storageprovider.ResourceId, error) { - empty := storageprovider.ResourceId{} +func GetDriveAndItemIDParam(r *http.Request, logger *log.Logger) (*storageprovider.ResourceId, *storageprovider.ResourceId, error) { + empty := &storageprovider.ResourceId{} driveID, err := parseIDParam(r, "driveID") if err != nil { @@ -61,16 +61,16 @@ func GetDriveAndItemIDParam(r *http.Request, logger *log.Logger) (storageprovide } if itemID.GetOpaqueId() == "" { - logger.Debug().Interface("driveID", driveID).Interface("itemID", itemID).Msg("empty item opaqueID") + logger.Debug().Interface("driveID", &driveID).Interface("itemID", &itemID).Msg("empty item opaqueID") return empty, empty, errorcode.New(errorcode.InvalidRequest, "invalid itemID") } if driveID.GetStorageId() != itemID.GetStorageId() || driveID.GetSpaceId() != itemID.GetSpaceId() { - logger.Debug().Interface("driveID", driveID).Interface("itemID", itemID).Msg("driveID and itemID do not match") + logger.Debug().Interface("driveID", &driveID).Interface("itemID", &itemID).Msg("driveID and itemID do not match") return empty, empty, errorcode.New(errorcode.ItemNotFound, "driveID and itemID do not match") } - return driveID, itemID, nil + return &driveID, &itemID, nil } // GetFilterParam returns the $filter query parameter from the request. If you need to parse the filter use godata.ParseRequest @@ -96,7 +96,7 @@ func (g Graph) GetGatewayClient(w http.ResponseWriter, r *http.Request) (gateway } // IsShareJail returns true if given id is a share jail id. -func IsShareJail(id storageprovider.ResourceId) bool { +func IsShareJail(id *storageprovider.ResourceId) bool { return id.GetStorageId() == utils.ShareStorageProviderID && id.GetSpaceId() == utils.ShareStorageSpaceID } @@ -511,7 +511,7 @@ func federatedRoleConditionForResourceType(ri *storageprovider.ResourceInfo) (st // ExtractShareIdFromResourceId is a bit of a hack. // We should not rely on a specific format of the item id. // But currently there is no other way to get the ShareID. -func ExtractShareIdFromResourceId(rid storageprovider.ResourceId) *collaboration.ShareId { +func ExtractShareIdFromResourceId(rid *storageprovider.ResourceId) *collaboration.ShareId { return &collaboration.ShareId{ OpaqueId: rid.GetOpaqueId(), } diff --git a/services/graph/pkg/service/v0/utils_test.go b/services/graph/pkg/service/v0/utils_test.go index 03b56197f20..b9dbc496312 100644 --- a/services/graph/pkg/service/v0/utils_test.go +++ b/services/graph/pkg/service/v0/utils_test.go @@ -15,6 +15,7 @@ import ( provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1" "github.com/cs3org/reva/v2/pkg/storagespace" + "google.golang.org/protobuf/testing/protocmp" "github.com/owncloud/ocis/v2/ocis-pkg/conversions" "github.com/owncloud/ocis/v2/ocis-pkg/log" @@ -42,10 +43,10 @@ var _ = Describe("Utils", func() { case true: Expect(err).To(BeNil()) parsedItemID, _ := storagespace.ParseID(itemID) - Expect(extractedItemID).To(Equal(parsedItemID)) + Expect(extractedItemID).To(BeComparableTo(&parsedItemID, protocmp.Transform())) parsedDriveID, _ := storagespace.ParseID(driveID) - Expect(extractedDriveID).To(Equal(parsedDriveID)) + Expect(extractedDriveID).To(BeComparableTo(&parsedDriveID, protocmp.Transform())) default: Expect(err).ToNot(BeNil()) } @@ -82,29 +83,29 @@ var _ = Describe("Utils", func() { ) DescribeTable("IsShareJail", - func(resourceID provider.ResourceId, isShareJail bool) { + func(resourceID *provider.ResourceId, isShareJail bool) { Expect(service.IsShareJail(resourceID)).To(Equal(isShareJail)) }, - Entry("valid: share jail", provider.ResourceId{ + Entry("valid: share jail", &provider.ResourceId{ StorageId: utils.ShareStorageProviderID, SpaceId: utils.ShareStorageSpaceID, }, true), - Entry("invalid: empty storageId", provider.ResourceId{ + Entry("invalid: empty storageId", &provider.ResourceId{ SpaceId: utils.ShareStorageSpaceID, }, false), - Entry("invalid: empty spaceId", provider.ResourceId{ + Entry("invalid: empty spaceId", &provider.ResourceId{ StorageId: utils.ShareStorageProviderID, }, false), - Entry("invalid: empty storageId and spaceId", provider.ResourceId{}, false), - Entry("invalid: non share jail storageId", provider.ResourceId{ + Entry("invalid: empty storageId and spaceId", &provider.ResourceId{}, false), + Entry("invalid: non share jail storageId", &provider.ResourceId{ StorageId: "123", SpaceId: utils.ShareStorageSpaceID, }, false), - Entry("invalid: non share jail spaceId", provider.ResourceId{ + Entry("invalid: non share jail spaceId", &provider.ResourceId{ StorageId: utils.ShareStorageProviderID, SpaceId: "123", }, false), - Entry("invalid: non share jail storageID and spaceId", provider.ResourceId{ + Entry("invalid: non share jail storageID and spaceId", &provider.ResourceId{ StorageId: "123", SpaceId: "123", }, false), diff --git a/vendor/google.golang.org/protobuf/internal/msgfmt/format.go b/vendor/google.golang.org/protobuf/internal/msgfmt/format.go new file mode 100644 index 00000000000..17b3f27b8a2 --- /dev/null +++ b/vendor/google.golang.org/protobuf/internal/msgfmt/format.go @@ -0,0 +1,261 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package msgfmt implements a text marshaler combining the desirable features +// of both the JSON and proto text formats. +// It is optimized for human readability and has no associated deserializer. +package msgfmt + +import ( + "bytes" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "time" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/detrand" + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/internal/order" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" +) + +// Format returns a formatted string for the message. +func Format(m proto.Message) string { + return string(appendMessage(nil, m.ProtoReflect())) +} + +// FormatValue returns a formatted string for an arbitrary value. +func FormatValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) string { + return string(appendValue(nil, v, fd)) +} + +func appendValue(b []byte, v protoreflect.Value, fd protoreflect.FieldDescriptor) []byte { + switch v := v.Interface().(type) { + case nil: + return append(b, ""...) + case bool, int32, int64, uint32, uint64, float32, float64: + return append(b, fmt.Sprint(v)...) + case string: + return append(b, strconv.Quote(string(v))...) + case []byte: + return append(b, strconv.Quote(string(v))...) + case protoreflect.EnumNumber: + return appendEnum(b, v, fd) + case protoreflect.Message: + return appendMessage(b, v) + case protoreflect.List: + return appendList(b, v, fd) + case protoreflect.Map: + return appendMap(b, v, fd) + default: + panic(fmt.Sprintf("invalid type: %T", v)) + } +} + +func appendEnum(b []byte, v protoreflect.EnumNumber, fd protoreflect.FieldDescriptor) []byte { + if fd != nil { + if ev := fd.Enum().Values().ByNumber(v); ev != nil { + return append(b, ev.Name()...) + } + } + return strconv.AppendInt(b, int64(v), 10) +} + +func appendMessage(b []byte, m protoreflect.Message) []byte { + if b2 := appendKnownMessage(b, m); b2 != nil { + return b2 + } + + b = append(b, '{') + order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + b = append(b, fd.TextName()...) + b = append(b, ':') + b = appendValue(b, v, fd) + b = append(b, delim()...) + return true + }) + b = appendUnknown(b, m.GetUnknown()) + b = bytes.TrimRight(b, delim()) + b = append(b, '}') + return b +} + +var protocmpMessageType = reflect.TypeOf(map[string]any(nil)) + +func appendKnownMessage(b []byte, m protoreflect.Message) []byte { + md := m.Descriptor() + fds := md.Fields() + switch md.FullName() { + case genid.Any_message_fullname: + var msgVal protoreflect.Message + url := m.Get(fds.ByNumber(genid.Any_TypeUrl_field_number)).String() + if v := reflect.ValueOf(m); v.Type().ConvertibleTo(protocmpMessageType) { + // For protocmp.Message, directly obtain the sub-message value + // which is stored in structured form, rather than as raw bytes. + m2 := v.Convert(protocmpMessageType).Interface().(map[string]any) + v, ok := m2[string(genid.Any_Value_field_name)].(proto.Message) + if !ok { + return nil + } + msgVal = v.ProtoReflect() + } else { + val := m.Get(fds.ByNumber(genid.Any_Value_field_number)).Bytes() + mt, err := protoregistry.GlobalTypes.FindMessageByURL(url) + if err != nil { + return nil + } + msgVal = mt.New() + err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(val, msgVal.Interface()) + if err != nil { + return nil + } + } + + b = append(b, '{') + b = append(b, "["+url+"]"...) + b = append(b, ':') + b = appendMessage(b, msgVal) + b = append(b, '}') + return b + + case genid.Timestamp_message_fullname: + secs := m.Get(fds.ByNumber(genid.Timestamp_Seconds_field_number)).Int() + nanos := m.Get(fds.ByNumber(genid.Timestamp_Nanos_field_number)).Int() + if nanos < 0 || nanos >= 1e9 { + return nil + } + t := time.Unix(secs, nanos).UTC() + x := t.Format("2006-01-02T15:04:05.000000000") // RFC 3339 + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, ".000") + return append(b, x+"Z"...) + + case genid.Duration_message_fullname: + sign := "" + secs := m.Get(fds.ByNumber(genid.Duration_Seconds_field_number)).Int() + nanos := m.Get(fds.ByNumber(genid.Duration_Nanos_field_number)).Int() + if nanos <= -1e9 || nanos >= 1e9 || (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) { + return nil + } + if secs < 0 || nanos < 0 { + sign, secs, nanos = "-", -1*secs, -1*nanos + } + x := fmt.Sprintf("%s%d.%09d", sign, secs, nanos) + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, ".000") + return append(b, x+"s"...) + + case genid.BoolValue_message_fullname, + genid.Int32Value_message_fullname, + genid.Int64Value_message_fullname, + genid.UInt32Value_message_fullname, + genid.UInt64Value_message_fullname, + genid.FloatValue_message_fullname, + genid.DoubleValue_message_fullname, + genid.StringValue_message_fullname, + genid.BytesValue_message_fullname: + fd := fds.ByNumber(genid.WrapperValue_Value_field_number) + return appendValue(b, m.Get(fd), fd) + } + + return nil +} + +func appendUnknown(b []byte, raw protoreflect.RawFields) []byte { + rs := make(map[protoreflect.FieldNumber][]protoreflect.RawFields) + for len(raw) > 0 { + num, _, n := protowire.ConsumeField(raw) + rs[num] = append(rs[num], raw[:n]) + raw = raw[n:] + } + + var ns []protoreflect.FieldNumber + for n := range rs { + ns = append(ns, n) + } + sort.Slice(ns, func(i, j int) bool { return ns[i] < ns[j] }) + + for _, n := range ns { + var leftBracket, rightBracket string + if len(rs[n]) > 1 { + leftBracket, rightBracket = "[", "]" + } + + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ':') + b = append(b, leftBracket...) + for _, r := range rs[n] { + num, typ, n := protowire.ConsumeTag(r) + r = r[n:] + switch typ { + case protowire.VarintType: + v, _ := protowire.ConsumeVarint(r) + b = strconv.AppendInt(b, int64(v), 10) + case protowire.Fixed32Type: + v, _ := protowire.ConsumeFixed32(r) + b = append(b, fmt.Sprintf("0x%08x", v)...) + case protowire.Fixed64Type: + v, _ := protowire.ConsumeFixed64(r) + b = append(b, fmt.Sprintf("0x%016x", v)...) + case protowire.BytesType: + v, _ := protowire.ConsumeBytes(r) + b = strconv.AppendQuote(b, string(v)) + case protowire.StartGroupType: + v, _ := protowire.ConsumeGroup(num, r) + b = append(b, '{') + b = appendUnknown(b, v) + b = bytes.TrimRight(b, delim()) + b = append(b, '}') + default: + panic(fmt.Sprintf("invalid type: %v", typ)) + } + b = append(b, delim()...) + } + b = bytes.TrimRight(b, delim()) + b = append(b, rightBracket...) + b = append(b, delim()...) + } + return b +} + +func appendList(b []byte, v protoreflect.List, fd protoreflect.FieldDescriptor) []byte { + b = append(b, '[') + for i := 0; i < v.Len(); i++ { + b = appendValue(b, v.Get(i), fd) + b = append(b, delim()...) + } + b = bytes.TrimRight(b, delim()) + b = append(b, ']') + return b +} + +func appendMap(b []byte, v protoreflect.Map, fd protoreflect.FieldDescriptor) []byte { + b = append(b, '{') + order.RangeEntries(v, order.GenericKeyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool { + b = appendValue(b, k.Value(), fd.MapKey()) + b = append(b, ':') + b = appendValue(b, v, fd.MapValue()) + b = append(b, delim()...) + return true + }) + b = bytes.TrimRight(b, delim()) + b = append(b, '}') + return b +} + +func delim() string { + // Deliberately introduce instability into the message string to + // discourage users from depending on it. + if detrand.Bool() { + return " " + } + return ", " +} diff --git a/vendor/google.golang.org/protobuf/testing/protocmp/reflect.go b/vendor/google.golang.org/protobuf/testing/protocmp/reflect.go new file mode 100644 index 00000000000..36f7cd9e6ea --- /dev/null +++ b/vendor/google.golang.org/protobuf/testing/protocmp/reflect.go @@ -0,0 +1,258 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protocmp + +import ( + "reflect" + "sort" + "strconv" + "strings" + + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoiface" +) + +func reflectValueOf(v any) protoreflect.Value { + switch v := v.(type) { + case Enum: + return protoreflect.ValueOfEnum(v.Number()) + case Message: + return protoreflect.ValueOfMessage(v.ProtoReflect()) + case []byte: + return protoreflect.ValueOfBytes(v) // avoid overlap with reflect.Slice check below + default: + switch rv := reflect.ValueOf(v); { + case rv.Kind() == reflect.Slice: + return protoreflect.ValueOfList(reflectList{rv}) + case rv.Kind() == reflect.Map: + return protoreflect.ValueOfMap(reflectMap{rv}) + default: + return protoreflect.ValueOf(v) + } + } +} + +type reflectMessage Message + +func (m reflectMessage) stringKey(fd protoreflect.FieldDescriptor) string { + if m.Descriptor() != fd.ContainingMessage() { + panic("mismatching containing message") + } + return fd.TextName() +} + +func (m reflectMessage) Descriptor() protoreflect.MessageDescriptor { + return (Message)(m).Descriptor() +} +func (m reflectMessage) Type() protoreflect.MessageType { + return reflectMessageType{m.Descriptor()} +} +func (m reflectMessage) New() protoreflect.Message { + return m.Type().New() +} +func (m reflectMessage) Interface() protoreflect.ProtoMessage { + return Message(m) +} +func (m reflectMessage) Range(f func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool) { + // Range over populated known fields. + fds := m.Descriptor().Fields() + for i := 0; i < fds.Len(); i++ { + fd := fds.Get(i) + if m.Has(fd) && !f(fd, m.Get(fd)) { + return + } + } + + // Range over populated extension fields. + for _, xd := range m[messageTypeKey].(messageMeta).xds { + if m.Has(xd) && !f(xd, m.Get(xd)) { + return + } + } +} +func (m reflectMessage) Has(fd protoreflect.FieldDescriptor) bool { + _, ok := m[m.stringKey(fd)] + return ok +} +func (m reflectMessage) Clear(protoreflect.FieldDescriptor) { + panic("invalid mutation of read-only message") +} +func (m reflectMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value { + v, ok := m[m.stringKey(fd)] + if !ok { + switch { + case fd.IsList(): + return protoreflect.ValueOfList(reflectList{}) + case fd.IsMap(): + return protoreflect.ValueOfMap(reflectMap{}) + case fd.Message() != nil: + return protoreflect.ValueOfMessage(reflectMessage{ + messageTypeKey: messageMeta{md: fd.Message()}, + }) + default: + return fd.Default() + } + } + + // The transformation may leave Any messages in structured form. + // If so, convert them back to a raw-encoded form. + if fd.FullName() == genid.Any_Value_field_fullname { + if m, ok := v.(Message); ok { + b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) + if err != nil { + panic("BUG: " + err.Error()) + } + return protoreflect.ValueOfBytes(b) + } + } + + return reflectValueOf(v) +} +func (m reflectMessage) Set(protoreflect.FieldDescriptor, protoreflect.Value) { + panic("invalid mutation of read-only message") +} +func (m reflectMessage) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value { + panic("invalid mutation of read-only message") +} +func (m reflectMessage) NewField(protoreflect.FieldDescriptor) protoreflect.Value { + panic("not implemented") +} +func (m reflectMessage) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { + if m.Descriptor().Oneofs().ByName(od.Name()) != od { + panic("oneof descriptor does not belong to this message") + } + fds := od.Fields() + for i := 0; i < fds.Len(); i++ { + fd := fds.Get(i) + if _, ok := m[m.stringKey(fd)]; ok { + return fd + } + } + return nil +} +func (m reflectMessage) GetUnknown() protoreflect.RawFields { + var nums []protoreflect.FieldNumber + for k := range m { + if len(strings.Trim(k, "0123456789")) == 0 { + n, _ := strconv.ParseUint(k, 10, 32) + nums = append(nums, protoreflect.FieldNumber(n)) + } + } + sort.Slice(nums, func(i, j int) bool { return nums[i] < nums[j] }) + + var raw protoreflect.RawFields + for _, num := range nums { + b, _ := m[strconv.FormatUint(uint64(num), 10)].(protoreflect.RawFields) + raw = append(raw, b...) + } + return raw +} +func (m reflectMessage) SetUnknown(protoreflect.RawFields) { + panic("invalid mutation of read-only message") +} +func (m reflectMessage) IsValid() bool { + invalid, _ := m[messageInvalidKey].(bool) + return !invalid +} +func (m reflectMessage) ProtoMethods() *protoiface.Methods { + return nil +} + +type reflectMessageType struct{ protoreflect.MessageDescriptor } + +func (t reflectMessageType) New() protoreflect.Message { + panic("not implemented") +} +func (t reflectMessageType) Zero() protoreflect.Message { + panic("not implemented") +} +func (t reflectMessageType) Descriptor() protoreflect.MessageDescriptor { + return t.MessageDescriptor +} + +type reflectList struct{ v reflect.Value } + +func (ls reflectList) Len() int { + if !ls.IsValid() { + return 0 + } + return ls.v.Len() +} +func (ls reflectList) Get(i int) protoreflect.Value { + return reflectValueOf(ls.v.Index(i).Interface()) +} +func (ls reflectList) Set(int, protoreflect.Value) { + panic("invalid mutation of read-only list") +} +func (ls reflectList) Append(protoreflect.Value) { + panic("invalid mutation of read-only list") +} +func (ls reflectList) AppendMutable() protoreflect.Value { + panic("invalid mutation of read-only list") +} +func (ls reflectList) Truncate(int) { + panic("invalid mutation of read-only list") +} +func (ls reflectList) NewElement() protoreflect.Value { + panic("not implemented") +} +func (ls reflectList) IsValid() bool { + return ls.v.IsValid() +} + +type reflectMap struct{ v reflect.Value } + +func (ms reflectMap) Len() int { + if !ms.IsValid() { + return 0 + } + return ms.v.Len() +} +func (ms reflectMap) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) { + if !ms.IsValid() { + return + } + ks := ms.v.MapKeys() + for _, k := range ks { + pk := reflectValueOf(k.Interface()).MapKey() + pv := reflectValueOf(ms.v.MapIndex(k).Interface()) + if !f(pk, pv) { + return + } + } +} +func (ms reflectMap) Has(k protoreflect.MapKey) bool { + if !ms.IsValid() { + return false + } + return ms.v.MapIndex(reflect.ValueOf(k.Interface())).IsValid() +} +func (ms reflectMap) Clear(protoreflect.MapKey) { + panic("invalid mutation of read-only list") +} +func (ms reflectMap) Get(k protoreflect.MapKey) protoreflect.Value { + if !ms.IsValid() { + return protoreflect.Value{} + } + v := ms.v.MapIndex(reflect.ValueOf(k.Interface())) + if !v.IsValid() { + return protoreflect.Value{} + } + return reflectValueOf(v.Interface()) +} +func (ms reflectMap) Set(protoreflect.MapKey, protoreflect.Value) { + panic("invalid mutation of read-only list") +} +func (ms reflectMap) Mutable(k protoreflect.MapKey) protoreflect.Value { + panic("invalid mutation of read-only list") +} +func (ms reflectMap) NewValue() protoreflect.Value { + panic("not implemented") +} +func (ms reflectMap) IsValid() bool { + return ms.v.IsValid() +} diff --git a/vendor/google.golang.org/protobuf/testing/protocmp/util.go b/vendor/google.golang.org/protobuf/testing/protocmp/util.go new file mode 100644 index 00000000000..838e70fbcc9 --- /dev/null +++ b/vendor/google.golang.org/protobuf/testing/protocmp/util.go @@ -0,0 +1,684 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protocmp + +import ( + "bytes" + "fmt" + "math" + "reflect" + "strings" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +var ( + enumReflectType = reflect.TypeOf(Enum{}) + messageReflectType = reflect.TypeOf(Message{}) +) + +// FilterEnum filters opt to only be applicable on a standalone [Enum], +// singular fields of enums, list fields of enums, or map fields of enum values, +// where the enum is the same type as the specified enum. +// +// The Go type of the last path step may be an: +// - [Enum] for singular fields, elements of a repeated field, +// values of a map field, or standalone [Enum] values +// - [][Enum] for list fields +// - map[K][Enum] for map fields +// - any for a [Message] map entry value +// +// This must be used in conjunction with [Transform]. +func FilterEnum(enum protoreflect.Enum, opt cmp.Option) cmp.Option { + return FilterDescriptor(enum.Descriptor(), opt) +} + +// FilterMessage filters opt to only be applicable on a standalone [Message] values, +// singular fields of messages, list fields of messages, or map fields of +// message values, where the message is the same type as the specified message. +// +// The Go type of the last path step may be an: +// - [Message] for singular fields, elements of a repeated field, +// values of a map field, or standalone [Message] values +// - [][Message] for list fields +// - map[K][Message] for map fields +// - any for a [Message] map entry value +// +// This must be used in conjunction with [Transform]. +func FilterMessage(message proto.Message, opt cmp.Option) cmp.Option { + return FilterDescriptor(message.ProtoReflect().Descriptor(), opt) +} + +// FilterField filters opt to only be applicable on the specified field +// in the message. It panics if a field of the given name does not exist. +// +// The Go type of the last path step may be an: +// - T for singular fields +// - []T for list fields +// - map[K]T for map fields +// - any for a [Message] map entry value +// +// This must be used in conjunction with [Transform]. +func FilterField(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option { + md := message.ProtoReflect().Descriptor() + return FilterDescriptor(mustFindFieldDescriptor(md, name), opt) +} + +// FilterOneof filters opt to only be applicable on all fields within the +// specified oneof in the message. It panics if a oneof of the given name +// does not exist. +// +// The Go type of the last path step may be an: +// - T for singular fields +// - []T for list fields +// - map[K]T for map fields +// - any for a [Message] map entry value +// +// This must be used in conjunction with [Transform]. +func FilterOneof(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option { + md := message.ProtoReflect().Descriptor() + return FilterDescriptor(mustFindOneofDescriptor(md, name), opt) +} + +// FilterDescriptor ignores the specified descriptor. +// +// The following descriptor types may be specified: +// - [protoreflect.EnumDescriptor] +// - [protoreflect.MessageDescriptor] +// - [protoreflect.FieldDescriptor] +// - [protoreflect.OneofDescriptor] +// +// For the behavior of each, see the corresponding filter function. +// Since this filter accepts a [protoreflect.FieldDescriptor], it can be used +// to also filter for extension fields as a [protoreflect.ExtensionDescriptor] +// is just an alias to [protoreflect.FieldDescriptor]. +// +// This must be used in conjunction with [Transform]. +func FilterDescriptor(desc protoreflect.Descriptor, opt cmp.Option) cmp.Option { + f := newNameFilters(desc) + return cmp.FilterPath(f.Filter, opt) +} + +// IgnoreEnums ignores all enums of the specified types. +// It is equivalent to FilterEnum(enum, cmp.Ignore()) for each enum. +// +// This must be used in conjunction with [Transform]. +func IgnoreEnums(enums ...protoreflect.Enum) cmp.Option { + var ds []protoreflect.Descriptor + for _, e := range enums { + ds = append(ds, e.Descriptor()) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreMessages ignores all messages of the specified types. +// It is equivalent to [FilterMessage](message, [cmp.Ignore]()) for each message. +// +// This must be used in conjunction with [Transform]. +func IgnoreMessages(messages ...proto.Message) cmp.Option { + var ds []protoreflect.Descriptor + for _, m := range messages { + ds = append(ds, m.ProtoReflect().Descriptor()) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreFields ignores the specified fields in the specified message. +// It is equivalent to [FilterField](message, name, [cmp.Ignore]()) for each field +// in the message. +// +// This must be used in conjunction with [Transform]. +func IgnoreFields(message proto.Message, names ...protoreflect.Name) cmp.Option { + var ds []protoreflect.Descriptor + md := message.ProtoReflect().Descriptor() + for _, s := range names { + ds = append(ds, mustFindFieldDescriptor(md, s)) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreOneofs ignores fields of the specified oneofs in the specified message. +// It is equivalent to FilterOneof(message, name, cmp.Ignore()) for each oneof +// in the message. +// +// This must be used in conjunction with [Transform]. +func IgnoreOneofs(message proto.Message, names ...protoreflect.Name) cmp.Option { + var ds []protoreflect.Descriptor + md := message.ProtoReflect().Descriptor() + for _, s := range names { + ds = append(ds, mustFindOneofDescriptor(md, s)) + } + return IgnoreDescriptors(ds...) +} + +// IgnoreDescriptors ignores the specified set of descriptors. +// It is equivalent to [FilterDescriptor](desc, [cmp.Ignore]()) for each descriptor. +// +// This must be used in conjunction with [Transform]. +func IgnoreDescriptors(descs ...protoreflect.Descriptor) cmp.Option { + return cmp.FilterPath(newNameFilters(descs...).Filter, cmp.Ignore()) +} + +func mustFindFieldDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.FieldDescriptor { + d := findDescriptor(md, s) + if fd, ok := d.(protoreflect.FieldDescriptor); ok && fd.TextName() == string(s) { + return fd + } + + var suggestion string + switch d := d.(type) { + case protoreflect.FieldDescriptor: + suggestion = fmt.Sprintf("; consider specifying field %q instead", d.TextName()) + case protoreflect.OneofDescriptor: + suggestion = fmt.Sprintf("; consider specifying oneof %q with IgnoreOneofs instead", d.Name()) + } + panic(fmt.Sprintf("message %q has no field %q%s", md.FullName(), s, suggestion)) +} + +func mustFindOneofDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.OneofDescriptor { + d := findDescriptor(md, s) + if od, ok := d.(protoreflect.OneofDescriptor); ok && d.Name() == s { + return od + } + + var suggestion string + switch d := d.(type) { + case protoreflect.OneofDescriptor: + suggestion = fmt.Sprintf("; consider specifying oneof %q instead", d.Name()) + case protoreflect.FieldDescriptor: + suggestion = fmt.Sprintf("; consider specifying field %q with IgnoreFields instead", d.TextName()) + } + panic(fmt.Sprintf("message %q has no oneof %q%s", md.FullName(), s, suggestion)) +} + +func findDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.Descriptor { + // Exact match. + if fd := md.Fields().ByTextName(string(s)); fd != nil { + return fd + } + if od := md.Oneofs().ByName(s); od != nil && !od.IsSynthetic() { + return od + } + + // Best-effort match. + // + // It's a common user mistake to use the CamelCased field name as it appears + // in the generated Go struct. Instead of complaining that it doesn't exist, + // suggest the real protobuf name that the user may have desired. + normalize := func(s protoreflect.Name) string { + return strings.Replace(strings.ToLower(string(s)), "_", "", -1) + } + for i := 0; i < md.Fields().Len(); i++ { + if fd := md.Fields().Get(i); normalize(fd.Name()) == normalize(s) { + return fd + } + } + for i := 0; i < md.Oneofs().Len(); i++ { + if od := md.Oneofs().Get(i); normalize(od.Name()) == normalize(s) { + return od + } + } + return nil +} + +type nameFilters struct { + names map[protoreflect.FullName]bool +} + +func newNameFilters(descs ...protoreflect.Descriptor) *nameFilters { + f := &nameFilters{names: make(map[protoreflect.FullName]bool)} + for _, d := range descs { + switch d := d.(type) { + case protoreflect.EnumDescriptor: + f.names[d.FullName()] = true + case protoreflect.MessageDescriptor: + f.names[d.FullName()] = true + case protoreflect.FieldDescriptor: + f.names[d.FullName()] = true + case protoreflect.OneofDescriptor: + for i := 0; i < d.Fields().Len(); i++ { + f.names[d.Fields().Get(i).FullName()] = true + } + default: + panic("invalid descriptor type") + } + } + return f +} + +func (f *nameFilters) Filter(p cmp.Path) bool { + vx, vy := p.Last().Values() + return (f.filterValue(vx) && f.filterValue(vy)) || f.filterFields(p) +} + +func (f *nameFilters) filterFields(p cmp.Path) bool { + // Trim off trailing type-assertions so that the filter can match on the + // concrete value held within an interface value. + if _, ok := p.Last().(cmp.TypeAssertion); ok { + p = p[:len(p)-1] + } + + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check field name. + vx, vy := ps.Values() + mx := vx.Interface().(Message) + my := vy.Interface().(Message) + k := mi.Key().String() + if f.filterFieldName(mx, k) && f.filterFieldName(my, k) { + return true + } + + // Check field value. + vx, vy = mi.Values() + if f.filterFieldValue(vx) && f.filterFieldValue(vy) { + return true + } + + return false +} + +func (f *nameFilters) filterFieldName(m Message, k string) bool { + if _, ok := m[k]; !ok { + return true // treat missing fields as already filtered + } + var fd protoreflect.FieldDescriptor + switch mm := m[messageTypeKey].(messageMeta); { + case protoreflect.Name(k).IsValid(): + fd = mm.md.Fields().ByTextName(k) + default: + fd = mm.xds[k] + } + if fd != nil { + return f.names[fd.FullName()] + } + return false +} + +func (f *nameFilters) filterFieldValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + v = v.Elem() // map entries are always populated values + switch t := v.Type(); { + case t == enumReflectType || t == messageReflectType: + // Check for singular message or enum field. + return f.filterValue(v) + case t.Kind() == reflect.Slice && (t.Elem() == enumReflectType || t.Elem() == messageReflectType): + // Check for list field of enum or message type. + return f.filterValue(v.Index(0)) + case t.Kind() == reflect.Map && (t.Elem() == enumReflectType || t.Elem() == messageReflectType): + // Check for map field of enum or message type. + return f.filterValue(v.MapIndex(v.MapKeys()[0])) + } + return false +} + +func (f *nameFilters) filterValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + if !v.CanInterface() { + return false // implies unexported struct field + } + switch v := v.Interface().(type) { + case Enum: + return v.Descriptor() != nil && f.names[v.Descriptor().FullName()] + case Message: + return v.Descriptor() != nil && f.names[v.Descriptor().FullName()] + } + return false +} + +// IgnoreDefaultScalars ignores singular scalars that are unpopulated or +// explicitly set to the default value. +// This option does not effect elements in a list or entries in a map. +// +// This must be used in conjunction with [Transform]. +func IgnoreDefaultScalars() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check whether both fields are default or unpopulated scalars. + vx, vy := ps.Values() + mx := vx.Interface().(Message) + my := vy.Interface().(Message) + k := mi.Key().String() + return isDefaultScalar(mx, k) && isDefaultScalar(my, k) + }, cmp.Ignore()) +} + +func isDefaultScalar(m Message, k string) bool { + if _, ok := m[k]; !ok { + return true + } + + var fd protoreflect.FieldDescriptor + switch mm := m[messageTypeKey].(messageMeta); { + case protoreflect.Name(k).IsValid(): + fd = mm.md.Fields().ByTextName(k) + default: + fd = mm.xds[k] + } + if fd == nil || !fd.Default().IsValid() { + return false + } + switch fd.Kind() { + case protoreflect.BytesKind: + v, ok := m[k].([]byte) + return ok && bytes.Equal(fd.Default().Bytes(), v) + case protoreflect.FloatKind: + v, ok := m[k].(float32) + return ok && equalFloat64(fd.Default().Float(), float64(v)) + case protoreflect.DoubleKind: + v, ok := m[k].(float64) + return ok && equalFloat64(fd.Default().Float(), float64(v)) + case protoreflect.EnumKind: + v, ok := m[k].(Enum) + return ok && fd.Default().Enum() == v.Number() + default: + return reflect.DeepEqual(fd.Default().Interface(), m[k]) + } +} + +func equalFloat64(x, y float64) bool { + return x == y || (math.IsNaN(x) && math.IsNaN(y)) +} + +// IgnoreEmptyMessages ignores messages that are empty or unpopulated. +// It applies to standalone [Message] values, singular message fields, +// list fields of messages, and map fields of message values. +// +// This must be used in conjunction with [Transform]. +func IgnoreEmptyMessages() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + vx, vy := p.Last().Values() + return (isEmptyMessage(vx) && isEmptyMessage(vy)) || isEmptyMessageFields(p) + }, cmp.Ignore()) +} + +func isEmptyMessageFields(p cmp.Path) bool { + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Check field value. + vx, vy := mi.Values() + if isEmptyMessageFieldValue(vx) && isEmptyMessageFieldValue(vy) { + return true + } + + return false +} + +func isEmptyMessageFieldValue(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + v = v.Elem() // map entries are always populated values + switch t := v.Type(); { + case t == messageReflectType: + // Check singular field for empty message. + if !isEmptyMessage(v) { + return false + } + case t.Kind() == reflect.Slice && t.Elem() == messageReflectType: + // Check list field for all empty message elements. + for i := 0; i < v.Len(); i++ { + if !isEmptyMessage(v.Index(i)) { + return false + } + } + case t.Kind() == reflect.Map && t.Elem() == messageReflectType: + // Check map field for all empty message values. + for _, k := range v.MapKeys() { + if !isEmptyMessage(v.MapIndex(k)) { + return false + } + } + default: + return false + } + return true +} + +func isEmptyMessage(v reflect.Value) bool { + if !v.IsValid() { + return true // implies missing slice element or map entry + } + if !v.CanInterface() { + return false // implies unexported struct field + } + if m, ok := v.Interface().(Message); ok { + for k := range m { + if k != messageTypeKey && k != messageInvalidKey { + return false + } + } + return true + } + return false +} + +// IgnoreUnknown ignores unknown fields in all messages. +// +// This must be used in conjunction with [Transform]. +func IgnoreUnknown() cmp.Option { + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter for Message maps. + mi, ok := p.Index(-1).(cmp.MapIndex) + if !ok { + return false + } + ps := p.Index(-2) + if ps.Type() != messageReflectType { + return false + } + + // Filter for unknown fields (which always have a numeric map key). + return strings.Trim(mi.Key().String(), "0123456789") == "" + }, cmp.Ignore()) +} + +// SortRepeated sorts repeated fields of the specified element type. +// The less function must be of the form "func(T, T) bool" where T is the +// Go element type for the repeated field kind. +// +// The element type T can be one of the following: +// - Go type for a protobuf scalar kind except for an enum +// (i.e., bool, int32, int64, uint32, uint64, float32, float64, string, and []byte) +// - E where E is a concrete enum type that implements [protoreflect.Enum] +// - M where M is a concrete message type that implement [proto.Message] +// +// This option only applies to repeated fields within a protobuf message. +// It does not operate on higher-order Go types that seem like a repeated field. +// For example, a []T outside the context of a protobuf message will not be +// handled by this option. To sort Go slices that are not repeated fields, +// consider using [github.com/google/go-cmp/cmp/cmpopts.SortSlices] instead. +// +// This must be used in conjunction with [Transform]. +func SortRepeated(lessFunc any) cmp.Option { + t, ok := checkTTBFunc(lessFunc) + if !ok { + panic(fmt.Sprintf("invalid less function: %T", lessFunc)) + } + + var opt cmp.Option + var sliceType reflect.Type + switch vf := reflect.ValueOf(lessFunc); { + case t.Implements(enumV2Type): + et := reflect.Zero(t).Interface().(protoreflect.Enum).Type() + lessFunc = func(x, y Enum) bool { + vx := reflect.ValueOf(et.New(x.Number())) + vy := reflect.ValueOf(et.New(y.Number())) + return vf.Call([]reflect.Value{vx, vy})[0].Bool() + } + opt = FilterDescriptor(et.Descriptor(), cmpopts.SortSlices(lessFunc)) + sliceType = reflect.SliceOf(enumReflectType) + case t.Implements(messageV2Type): + mt := reflect.Zero(t).Interface().(protoreflect.ProtoMessage).ProtoReflect().Type() + lessFunc = func(x, y Message) bool { + mx := mt.New().Interface() + my := mt.New().Interface() + proto.Merge(mx, x) + proto.Merge(my, y) + vx := reflect.ValueOf(mx) + vy := reflect.ValueOf(my) + return vf.Call([]reflect.Value{vx, vy})[0].Bool() + } + opt = FilterDescriptor(mt.Descriptor(), cmpopts.SortSlices(lessFunc)) + sliceType = reflect.SliceOf(messageReflectType) + default: + switch t { + case reflect.TypeOf(bool(false)): + case reflect.TypeOf(int32(0)): + case reflect.TypeOf(int64(0)): + case reflect.TypeOf(uint32(0)): + case reflect.TypeOf(uint64(0)): + case reflect.TypeOf(float32(0)): + case reflect.TypeOf(float64(0)): + case reflect.TypeOf(string("")): + case reflect.TypeOf([]byte(nil)): + default: + panic(fmt.Sprintf("invalid element type: %v", t)) + } + opt = cmpopts.SortSlices(lessFunc) + sliceType = reflect.SliceOf(t) + } + + return cmp.FilterPath(func(p cmp.Path) bool { + // Filter to only apply to repeated fields within a message. + if t := p.Index(-1).Type(); t == nil || t != sliceType { + return false + } + if t := p.Index(-2).Type(); t == nil || t.Kind() != reflect.Interface { + return false + } + if t := p.Index(-3).Type(); t == nil || t != messageReflectType { + return false + } + return true + }, opt) +} + +func checkTTBFunc(lessFunc any) (reflect.Type, bool) { + switch t := reflect.TypeOf(lessFunc); { + case t == nil: + return nil, false + case t.NumIn() != 2 || t.In(0) != t.In(1) || t.IsVariadic(): + return nil, false + case t.NumOut() != 1 || t.Out(0) != reflect.TypeOf(false): + return nil, false + default: + return t.In(0), true + } +} + +// SortRepeatedFields sorts the specified repeated fields. +// Sorting a repeated field is useful for treating the list as a multiset +// (i.e., a set where each value can appear multiple times). +// It panics if the field does not exist or is not a repeated field. +// +// The sort ordering is as follows: +// - Booleans are sorted where false is sorted before true. +// - Integers are sorted in ascending order. +// - Floating-point numbers are sorted in ascending order according to +// the total ordering defined by IEEE-754 (section 5.10). +// - Strings and bytes are sorted lexicographically in ascending order. +// - [Enum] values are sorted in ascending order based on its numeric value. +// - [Message] values are sorted according to some arbitrary ordering +// which is undefined and may change in future implementations. +// +// The ordering chosen for repeated messages is unlikely to be aesthetically +// preferred by humans. Consider using a custom sort function: +// +// FilterField(m, "foo_field", SortRepeated(func(x, y *foopb.MyMessage) bool { +// ... // user-provided definition for less +// })) +// +// This must be used in conjunction with [Transform]. +func SortRepeatedFields(message proto.Message, names ...protoreflect.Name) cmp.Option { + var opts cmp.Options + md := message.ProtoReflect().Descriptor() + for _, name := range names { + fd := mustFindFieldDescriptor(md, name) + if !fd.IsList() { + panic(fmt.Sprintf("message field %q is not repeated", fd.FullName())) + } + + var lessFunc any + switch fd.Kind() { + case protoreflect.BoolKind: + lessFunc = func(x, y bool) bool { return !x && y } + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + lessFunc = func(x, y int32) bool { return x < y } + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + lessFunc = func(x, y int64) bool { return x < y } + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + lessFunc = func(x, y uint32) bool { return x < y } + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + lessFunc = func(x, y uint64) bool { return x < y } + case protoreflect.FloatKind: + lessFunc = lessF32 + case protoreflect.DoubleKind: + lessFunc = lessF64 + case protoreflect.StringKind: + lessFunc = func(x, y string) bool { return x < y } + case protoreflect.BytesKind: + lessFunc = func(x, y []byte) bool { return bytes.Compare(x, y) < 0 } + case protoreflect.EnumKind: + lessFunc = func(x, y Enum) bool { return x.Number() < y.Number() } + case protoreflect.MessageKind, protoreflect.GroupKind: + lessFunc = func(x, y Message) bool { return x.String() < y.String() } + default: + panic(fmt.Sprintf("invalid kind: %v", fd.Kind())) + } + opts = append(opts, FilterDescriptor(fd, cmpopts.SortSlices(lessFunc))) + } + return opts +} + +func lessF32(x, y float32) bool { + // Bit-wise implementation of IEEE-754, section 5.10. + xi := int32(math.Float32bits(x)) + yi := int32(math.Float32bits(y)) + xi ^= int32(uint32(xi>>31) >> 1) + yi ^= int32(uint32(yi>>31) >> 1) + return xi < yi +} +func lessF64(x, y float64) bool { + // Bit-wise implementation of IEEE-754, section 5.10. + xi := int64(math.Float64bits(x)) + yi := int64(math.Float64bits(y)) + xi ^= int64(uint64(xi>>63) >> 1) + yi ^= int64(uint64(yi>>63) >> 1) + return xi < yi +} diff --git a/vendor/google.golang.org/protobuf/testing/protocmp/xform.go b/vendor/google.golang.org/protobuf/testing/protocmp/xform.go new file mode 100644 index 00000000000..de29d973269 --- /dev/null +++ b/vendor/google.golang.org/protobuf/testing/protocmp/xform.go @@ -0,0 +1,377 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package protocmp provides protobuf specific options for the +// [github.com/google/go-cmp/cmp] package. +// +// The primary feature is the [Transform] option, which transform [proto.Message] +// types into a [Message] map that is suitable for cmp to introspect upon. +// All other options in this package must be used in conjunction with [Transform]. +package protocmp + +import ( + "reflect" + "strconv" + + "github.com/google/go-cmp/cmp" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/genid" + "google.golang.org/protobuf/internal/msgfmt" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/runtime/protoiface" + "google.golang.org/protobuf/runtime/protoimpl" +) + +var ( + enumV2Type = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem() + messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem() + messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem() +) + +// Enum is a dynamic representation of a protocol buffer enum that is +// suitable for [cmp.Equal] and [cmp.Diff] to compare upon. +type Enum struct { + num protoreflect.EnumNumber + ed protoreflect.EnumDescriptor +} + +// Descriptor returns the enum descriptor. +// It returns nil for a zero Enum value. +func (e Enum) Descriptor() protoreflect.EnumDescriptor { + return e.ed +} + +// Number returns the enum value as an integer. +func (e Enum) Number() protoreflect.EnumNumber { + return e.num +} + +// Equal reports whether e1 and e2 represent the same enum value. +func (e1 Enum) Equal(e2 Enum) bool { + if e1.ed.FullName() != e2.ed.FullName() { + return false + } + return e1.num == e2.num +} + +// String returns the name of the enum value if known (e.g., "ENUM_VALUE"), +// otherwise it returns the formatted decimal enum number (e.g., "14"). +func (e Enum) String() string { + if ev := e.ed.Values().ByNumber(e.num); ev != nil { + return string(ev.Name()) + } + return strconv.Itoa(int(e.num)) +} + +const ( + // messageTypeKey indicates the protobuf message type. + // The value type is always messageMeta. + // From the public API, it presents itself as only the type, but the + // underlying data structure holds arbitrary metadata about the message. + messageTypeKey = "@type" + + // messageInvalidKey indicates that the message is invalid. + // The value is always the boolean "true". + messageInvalidKey = "@invalid" +) + +type messageMeta struct { + m proto.Message + md protoreflect.MessageDescriptor + xds map[string]protoreflect.ExtensionDescriptor +} + +func (t messageMeta) String() string { + return string(t.md.FullName()) +} + +func (t1 messageMeta) Equal(t2 messageMeta) bool { + return t1.md.FullName() == t2.md.FullName() +} + +// Message is a dynamic representation of a protocol buffer message that is +// suitable for [cmp.Equal] and [cmp.Diff] to directly operate upon. +// +// Every populated known field (excluding extension fields) is stored in the map +// with the key being the short name of the field (e.g., "field_name") and +// the value determined by the kind and cardinality of the field. +// +// Singular scalars are represented by the same Go type as [protoreflect.Value], +// singular messages are represented by the [Message] type, +// singular enums are represented by the [Enum] type, +// list fields are represented as a Go slice, and +// map fields are represented as a Go map. +// +// Every populated extension field is stored in the map with the key being the +// full name of the field surrounded by brackets (e.g., "[extension.full.name]") +// and the value determined according to the same rules as known fields. +// +// Every unknown field is stored in the map with the key being the field number +// encoded as a decimal string (e.g., "132") and the value being the raw bytes +// of the encoded field (as the [protoreflect.RawFields] type). +// +// Message values must not be created by or mutated by users. +type Message map[string]any + +// Unwrap returns the original message value. +// It returns nil if this Message was not constructed from another message. +func (m Message) Unwrap() proto.Message { + mm, _ := m[messageTypeKey].(messageMeta) + return mm.m +} + +// Descriptor return the message descriptor. +// It returns nil for a zero Message value. +func (m Message) Descriptor() protoreflect.MessageDescriptor { + mm, _ := m[messageTypeKey].(messageMeta) + return mm.md +} + +// ProtoReflect returns a reflective view of m. +// It only implements the read-only operations of [protoreflect.Message]. +// Calling any mutating operations on m panics. +func (m Message) ProtoReflect() protoreflect.Message { + return (reflectMessage)(m) +} + +// ProtoMessage is a marker method from the legacy message interface. +func (m Message) ProtoMessage() {} + +// Reset is the required Reset method from the legacy message interface. +func (m Message) Reset() { + panic("invalid mutation of a read-only message") +} + +// String returns a formatted string for the message. +// It is intended for human debugging and has no guarantees about its +// exact format or the stability of its output. +func (m Message) String() string { + switch { + case m == nil: + return "" + case !m.ProtoReflect().IsValid(): + return "" + default: + return msgfmt.Format(m) + } +} + +type transformer struct { + resolver protoregistry.MessageTypeResolver +} + +func newTransformer(opts ...option) *transformer { + xf := &transformer{ + resolver: protoregistry.GlobalTypes, + } + for _, opt := range opts { + opt(xf) + } + return xf +} + +type option func(*transformer) + +// MessageTypeResolver overrides the resolver used for messages packed +// inside Any. The default is protoregistry.GlobalTypes, which is +// sufficient for all compiled-in Protobuf messages. Overriding the +// resolver is useful in tests that dynamically create Protobuf +// descriptors and messages, e.g. in proxies using dynamicpb. +func MessageTypeResolver(r protoregistry.MessageTypeResolver) option { + return func(xf *transformer) { + xf.resolver = r + } +} + +// Transform returns a [cmp.Option] that converts each [proto.Message] to a [Message]. +// The transformation does not mutate nor alias any converted messages. +// +// The google.protobuf.Any message is automatically unmarshaled such that the +// "value" field is a [Message] representing the underlying message value +// assuming it could be resolved and properly unmarshaled. +// +// This does not directly transform higher-order composite Go types. +// For example, []*foopb.Message is not transformed into []Message, +// but rather the individual message elements of the slice are transformed. +func Transform(opts ...option) cmp.Option { + xf := newTransformer(opts...) + + // addrType returns a pointer to t if t isn't a pointer or interface. + addrType := func(t reflect.Type) reflect.Type { + if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr { + return t + } + return reflect.PtrTo(t) + } + + // TODO: Should this transform protoreflect.Enum types to Enum as well? + return cmp.FilterPath(func(p cmp.Path) bool { + ps := p.Last() + if isMessageType(addrType(ps.Type())) { + return true + } + + // Check whether the concrete values of an interface both satisfy + // the Message interface. + if ps.Type().Kind() == reflect.Interface { + vx, vy := ps.Values() + if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() { + return false + } + return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type())) + } + + return false + }, cmp.Transformer("protocmp.Transform", func(v any) Message { + // For user convenience, shallow copy the message value if necessary + // in order for it to implement the message interface. + if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) { + pv := reflect.New(rv.Type()) + pv.Elem().Set(rv) + v = pv.Interface() + } + + m := protoimpl.X.MessageOf(v) + switch { + case m == nil: + return nil + case !m.IsValid(): + return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true} + default: + return xf.transformMessage(m) + } + })) +} + +func isMessageType(t reflect.Type) bool { + // Avoid transforming the Message itself. + if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) { + return false + } + return t.Implements(messageV1Type) || t.Implements(messageV2Type) +} + +func (xf *transformer) transformMessage(m protoreflect.Message) Message { + mx := Message{} + mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)} + + // Handle known and extension fields. + m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + s := fd.TextName() + if fd.IsExtension() { + mt.xds[s] = fd + } + switch { + case fd.IsList(): + mx[s] = xf.transformList(fd, v.List()) + case fd.IsMap(): + mx[s] = xf.transformMap(fd, v.Map()) + default: + mx[s] = xf.transformSingular(fd, v) + } + return true + }) + + // Handle unknown fields. + for b := m.GetUnknown(); len(b) > 0; { + num, _, n := protowire.ConsumeField(b) + s := strconv.Itoa(int(num)) + b2, _ := mx[s].(protoreflect.RawFields) + mx[s] = append(b2, b[:n]...) + b = b[n:] + } + + // Expand Any messages. + if mt.md.FullName() == genid.Any_message_fullname { + s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string) + b, _ := mx[string(genid.Any_Value_field_name)].([]byte) + mt, err := xf.resolver.FindMessageByURL(s) + if mt != nil && err == nil { + m2 := mt.New() + err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface()) + if err == nil { + mx[string(genid.Any_Value_field_name)] = xf.transformMessage(m2) + } + } + } + + mx[messageTypeKey] = mt + return mx +} + +func (xf *transformer) transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) any { + t := protoKindToGoType(fd.Kind()) + rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len()) + for i := 0; i < lv.Len(); i++ { + v := reflect.ValueOf(xf.transformSingular(fd, lv.Get(i))) + rv.Index(i).Set(v) + } + return rv.Interface() +} + +func (xf *transformer) transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) any { + kfd := fd.MapKey() + vfd := fd.MapValue() + kt := protoKindToGoType(kfd.Kind()) + vt := protoKindToGoType(vfd.Kind()) + rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len()) + mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + kv := reflect.ValueOf(xf.transformSingular(kfd, k.Value())) + vv := reflect.ValueOf(xf.transformSingular(vfd, v)) + rv.SetMapIndex(kv, vv) + return true + }) + return rv.Interface() +} + +func (xf *transformer) transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) any { + switch fd.Kind() { + case protoreflect.EnumKind: + return Enum{num: v.Enum(), ed: fd.Enum()} + case protoreflect.MessageKind, protoreflect.GroupKind: + return xf.transformMessage(v.Message()) + case protoreflect.BytesKind: + // The protoreflect API does not specify whether an empty bytes is + // guaranteed to be nil or not. Always return non-nil bytes to avoid + // leaking information about the concrete proto.Message implementation. + if len(v.Bytes()) == 0 { + return []byte{} + } + return v.Bytes() + default: + return v.Interface() + } +} + +func protoKindToGoType(k protoreflect.Kind) reflect.Type { + switch k { + case protoreflect.BoolKind: + return reflect.TypeOf(bool(false)) + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + return reflect.TypeOf(int32(0)) + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + return reflect.TypeOf(int64(0)) + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + return reflect.TypeOf(uint32(0)) + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + return reflect.TypeOf(uint64(0)) + case protoreflect.FloatKind: + return reflect.TypeOf(float32(0)) + case protoreflect.DoubleKind: + return reflect.TypeOf(float64(0)) + case protoreflect.StringKind: + return reflect.TypeOf(string("")) + case protoreflect.BytesKind: + return reflect.TypeOf([]byte(nil)) + case protoreflect.EnumKind: + return reflect.TypeOf(Enum{}) + case protoreflect.MessageKind, protoreflect.GroupKind: + return reflect.TypeOf(Message{}) + default: + panic("invalid kind") + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 8205ec5fed1..1ab40eb36c3 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -2353,6 +2353,7 @@ google.golang.org/protobuf/internal/filetype google.golang.org/protobuf/internal/flags google.golang.org/protobuf/internal/genid google.golang.org/protobuf/internal/impl +google.golang.org/protobuf/internal/msgfmt google.golang.org/protobuf/internal/order google.golang.org/protobuf/internal/pragma google.golang.org/protobuf/internal/set @@ -2365,6 +2366,7 @@ google.golang.org/protobuf/reflect/protoreflect google.golang.org/protobuf/reflect/protoregistry google.golang.org/protobuf/runtime/protoiface google.golang.org/protobuf/runtime/protoimpl +google.golang.org/protobuf/testing/protocmp google.golang.org/protobuf/types/descriptorpb google.golang.org/protobuf/types/gofeaturespb google.golang.org/protobuf/types/known/anypb