support putifabsent (#8428)
Added support for the If-None-Match header in the S3 gateway's PutObject and CompleteMultipartUpload operations.
ItamarYuran authored Jan 1, 2025
1 parent 54977e5 commit bde54c0
Showing 5 changed files with 177 additions and 15 deletions.
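
For context, here is a minimal client-side sketch (not part of this change) of the behavior being added: it writes an object through the lakeFS S3 gateway only if the key does not already exist, by injecting an If-None-Match: * header with an AWS SDK for Go v2 build middleware, the same way the new tests below do. The gateway endpoint, credentials, repository name ("example-repo"), and object key are placeholders.

package main

import (
	"context"
	"log"
	"strings"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
)

// withIfNoneMatch injects an If-None-Match header into the outgoing request,
// mirroring the setHTTPHeaders helper added in esti/s3_gateway_test.go below.
func withIfNoneMatch(value string) func(*middleware.Stack) error {
	return func(stack *middleware.Stack) error {
		return stack.Build.Add(middleware.BuildMiddlewareFunc("AddIfNoneMatchHeader", func(
			ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
		) (middleware.BuildOutput, middleware.Metadata, error) {
			if req, ok := in.Request.(*smithyhttp.Request); ok {
				req.Header.Set("If-None-Match", value)
			}
			return next.HandleBuild(ctx, in)
		}), middleware.Before)
	}
}

func main() {
	ctx := context.Background()
	cfg, err := config.LoadDefaultConfig(ctx) // credentials are assumed to come from the environment
	if err != nil {
		log.Fatal(err)
	}
	client := s3.NewFromConfig(cfg, func(o *s3.Options) {
		o.BaseEndpoint = aws.String("http://localhost:8000") // lakeFS S3 gateway endpoint (placeholder)
		o.UsePathStyle = true
	})

	// Write only if "main/data/object1" does not exist yet in repository "example-repo".
	_, err = client.PutObject(ctx, &s3.PutObjectInput{
		Bucket: aws.String("example-repo"),
		Key:    aws.String("main/data/object1"),
		Body:   strings.NewReader("hello"),
	}, s3.WithAPIOptions(withIfNoneMatch("*")))
	if err != nil {
		log.Printf("put failed (object may already exist): %v", err)
	}
}

Repeating the call for the same key should now fail with a PreconditionFailed error; omitting the header preserves the previous overwrite behavior, and any other header value is rejected as not implemented.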
6 changes: 3 additions & 3 deletions esti/multipart_test.go
@@ -52,7 +52,7 @@ func TestMultipartUpload(t *testing.T) {
partsConcat = append(partsConcat, parts[i]...)
}

completedParts := uploadMultipartParts(t, ctx, logger, resp, parts, 0)
completedParts := uploadMultipartParts(t, ctx, svc, logger, resp, parts, 0)

if isBlockstoreType(block.BlockstoreTypeS3) == nil {
// Object should have Last-Modified time at around time of MPU creation. Ensure
@@ -166,7 +166,7 @@ func reverse(s string) string {
return string(runes)
}

func uploadMultipartParts(t *testing.T, ctx context.Context, logger logging.Logger, resp *s3.CreateMultipartUploadOutput, parts [][]byte, firstIndex int) []types.CompletedPart {
func uploadMultipartParts(t *testing.T, ctx context.Context, client *s3.Client, logger logging.Logger, resp *s3.CreateMultipartUploadOutput, parts [][]byte, firstIndex int) []types.CompletedPart {
count := len(parts)
completedParts := make([]types.CompletedPart, count)
errs := make([]error, count)
@@ -176,7 +176,7 @@ func uploadMultipartParts(t *testing.T, ctx context.Context, logger logging.Logg
go func(i int) {
defer wg.Done()
partNumber := firstIndex + i + 1
completedParts[i], errs[i] = uploadMultipartPart(ctx, logger, svc, resp, parts[i], partNumber)
completedParts[i], errs[i] = uploadMultipartPart(ctx, logger, client, resp, parts[i], partNumber)
}(i)
}
wg.Wait()
121 changes: 114 additions & 7 deletions esti/s3_gateway_test.go
@@ -4,9 +4,6 @@ import (
"bytes"
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/go-openapi/swag"
"io"
"math/rand"
"net/http"
@@ -16,6 +13,14 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"github.com/go-openapi/swag"
"github.com/thanhpk/randstr"

"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/minio/minio-go/v7/pkg/tags"
@@ -31,10 +36,12 @@ import (
type GetCredentials = func(id, secret, token string) *credentials.Credentials

const (
numUploads = 100
randomDataPathLength = 1020
branch = "main"
gatewayTestPrefix = branch + "/data/"
numUploads = 100
randomDataPathLength = 1020
branch = "main"
gatewayTestPrefix = branch + "/data/"
errorPreconditionFailed = "At least one of the pre-conditions you specified did not hold"
errorNotImplemented = "A header you provided implies functionality that is not implemented"
)

func newMinioClient(t *testing.T, getCredentials GetCredentials) *minio.Client {
@@ -181,6 +188,106 @@ func TestS3UploadAndDownload(t *testing.T) {
})
}
}
func TestMultipartUploadIfNoneMatch(t *testing.T) {
ctx, logger, repo := setupTest(t)
defer tearDownTest(repo)
s3Endpoint := viper.GetString("s3_endpoint")
s3Client := createS3Client(s3Endpoint, t)
multipartNumberOfParts := 7
multipartPartSize := 5 * 1024 * 1024
type TestCase struct {
Path string
IfNoneMatch string
ExpectedError string
}

testCases := []TestCase{
{Path: "main/object1"},
{Path: "main/object1", IfNoneMatch: "*", ExpectedError: errorPreconditionFailed},
{Path: "main/object2", IfNoneMatch: "*"},
}
for _, tc := range testCases {
input := &s3.CreateMultipartUploadInput{
Bucket: aws.String(repo),
Key: aws.String(tc.Path),
}

resp, err := s3Client.CreateMultipartUpload(ctx, input)
require.NoError(t, err, "failed to create multipart upload")

parts := make([][]byte, multipartNumberOfParts)
for i := 0; i < multipartNumberOfParts; i++ {
parts[i] = randstr.Bytes(multipartPartSize + i)
}

completedParts := uploadMultipartParts(t, ctx, s3Client, logger, resp, parts, 0)

completeInput := &s3.CompleteMultipartUploadInput{
Bucket: resp.Bucket,
Key: resp.Key,
UploadId: resp.UploadId,
MultipartUpload: &types.CompletedMultipartUpload{
Parts: completedParts,
},
}
_, err = s3Client.CompleteMultipartUpload(ctx, completeInput, s3.WithAPIOptions(setHTTPHeaders(tc.IfNoneMatch)))
if tc.ExpectedError != "" {
require.ErrorContains(t, err, tc.ExpectedError)
} else {
require.NoError(t, err, "expected no error but got %w")
}
}
}

func setHTTPHeaders(ifNoneMatch string) func(*middleware.Stack) error {
return func(stack *middleware.Stack) error {
return stack.Build.Add(middleware.BuildMiddlewareFunc("AddIfNoneMatchHeader", func(
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
middleware.BuildOutput, middleware.Metadata, error,
) {
if req, ok := in.Request.(*smithyhttp.Request); ok {
// Add the If-None-Match header
req.Header.Set("If-None-Match", ifNoneMatch)
}
return next.HandleBuild(ctx, in)
}), middleware.Before)
}
}
func TestS3IfNoneMatch(t *testing.T) {

ctx, _, repo := setupTest(t)
defer tearDownTest(repo)

s3Endpoint := viper.GetString("s3_endpoint")
s3Client := createS3Client(s3Endpoint, t)

type TestCase struct {
Path string
IfNoneMatch string
ExpectedError string
}

testCases := []TestCase{
{Path: "main/object1"},
{Path: "main/object1", IfNoneMatch: "*", ExpectedError: errorPreconditionFailed},
{Path: "main/object2", IfNoneMatch: "*"},
{Path: "main/object2"},
{Path: "main/object3", IfNoneMatch: "unsupported string", ExpectedError: errorNotImplemented},
}
for _, tc := range testCases {
input := &s3.PutObjectInput{
Bucket: aws.String(repo),
Key: aws.String(tc.Path),
}
_, err := s3Client.PutObject(ctx, input, s3.WithAPIOptions(setHTTPHeaders(tc.IfNoneMatch)))
if tc.ExpectedError != "" {
require.ErrorContains(t, err, tc.ExpectedError)
} else {
require.NoError(t, err, "expected no error but got %w")
}
}
}

func verifyObjectInfo(t *testing.T, got minio.ObjectInfo, expectedSize int) {
if got.Err != nil {
5 changes: 3 additions & 2 deletions pkg/gateway/operations/operation_utils.go
@@ -6,6 +6,7 @@ import (
"time"

"github.com/treeverse/lakefs/pkg/catalog"
"github.com/treeverse/lakefs/pkg/graveler"
"github.com/treeverse/lakefs/pkg/logging"
)

@@ -40,7 +41,7 @@ func shouldReplaceMetadata(req *http.Request) bool {
return req.Header.Get(amzMetadataDirectiveHeaderPrefix) == "REPLACE"
}

func (o *PathOperation) finishUpload(req *http.Request, mTime *time.Time, checksum, physicalAddress string, size int64, relative bool, metadata map[string]string, contentType string) error {
func (o *PathOperation) finishUpload(req *http.Request, mTime *time.Time, checksum, physicalAddress string, size int64, relative bool, metadata map[string]string, contentType string, allowOverwrite bool) error {
var writeTime time.Time
if mTime == nil {
writeTime = time.Now()
@@ -59,7 +60,7 @@ func (o *PathOperation) finishUpload(req *http.Request, mTime *time.Time, checks
ContentType(contentType).
Build()

err := o.Catalog.CreateEntry(req.Context(), o.Repository.Name, o.Reference, entry)
err := o.Catalog.CreateEntry(req.Context(), o.Repository.Name, o.Reference, entry, graveler.WithIfAbsent(!allowOverwrite))
if err != nil {
o.Log(req).WithError(err).Error("could not update metadata")
return err
24 changes: 23 additions & 1 deletion pkg/gateway/operations/postobject.go
@@ -10,6 +10,7 @@ import (
"time"

"github.com/treeverse/lakefs/pkg/block"
"github.com/treeverse/lakefs/pkg/catalog"
gatewayErrors "github.com/treeverse/lakefs/pkg/gateway/errors"
"github.com/treeverse/lakefs/pkg/gateway/multipart"
"github.com/treeverse/lakefs/pkg/gateway/path"
@@ -94,6 +95,23 @@ func (controller *PostObject) HandleCompleteMultipartUpload(w http.ResponseWrite
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError))
return
}
// Check whether an If-None-Match header was provided and validate it
allowOverwrite, err := o.checkIfAbsent(req)
if err != nil {
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrNotImplemented))
return
}
// Before writing the body, ensure the precondition holds. This means we effectively check for object
// existence twice: once here, before uploading the body, to save resources and time,
// and once more in graveler, which re-checks when it is passed the corresponding SetOptions.
if !allowOverwrite {
_, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, o.Reference, o.Path, catalog.GetEntryParams{})
if err == nil {
// GetEntry returned no error, so the object already exists and the precondition fails
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed))
return
}
}
objName := multiPart.PhysicalAddress
req = req.WithContext(logging.AddFields(req.Context(), logging.Fields{logging.PhysicalAddressFieldKey: objName}))
xmlMultipartComplete, err := io.ReadAll(req.Body)
@@ -124,7 +142,11 @@ func (controller *PostObject) HandleCompleteMultipartUpload(w http.ResponseWrite
return
}
checksum := strings.Split(resp.ETag, "-")[0]
err = o.finishUpload(req, resp.MTime, checksum, objName, resp.ContentLength, true, multiPart.Metadata, multiPart.ContentType)
err = o.finishUpload(req, resp.MTime, checksum, objName, resp.ContentLength, true, multiPart.Metadata, multiPart.ContentType, allowOverwrite)
if errors.Is(err, graveler.ErrPreconditionFailed) {
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed))
return
}
if errors.Is(err, graveler.ErrWriteToProtectedBranch) {
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrWriteToProtectedBranch))
return
36 changes: 34 additions & 2 deletions pkg/gateway/operations/putobject.go
@@ -20,6 +20,7 @@ import (
)

const (
IfNoneMatchHeader = "If-None-Match"
CopySourceHeader = "x-amz-copy-source"
CopySourceRangeHeader = "x-amz-copy-source-range"
QueryParamUploadID = "uploadId"
@@ -30,7 +31,6 @@ type PutObject struct{}

func (controller *PutObject) RequiredPermissions(req *http.Request, repoID, _, destPath string) (permissions.Node, error) {
copySource := req.Header.Get(CopySourceHeader)

if len(copySource) == 0 {
return permissions.Node{
Permission: permissions.Permission{
@@ -298,6 +298,23 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) {
o.Incr("put_object", o.Principal, o.Repository.Name, o.Reference)
storageClass := StorageClassFromHeader(req.Header)
opts := block.PutOpts{StorageClass: storageClass}
// Check whether an If-None-Match header was provided and validate it
allowOverwrite, err := o.checkIfAbsent(req)
if err != nil {
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrNotImplemented))
return
}
// Before writing the body, ensure the precondition holds. This means we effectively check for object
// existence twice: once here, before uploading the body, to save resources and time,
// and once more in graveler, which re-checks when it is passed the corresponding SetOptions.
if !allowOverwrite {
_, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, o.Reference, o.Path, catalog.GetEntryParams{})
if err == nil {
// GetEntry returned no error, so the object already exists and the precondition fails
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed))
return
}
}
address := o.PathProvider.NewPath()
blob, err := upload.WriteBlob(req.Context(), o.BlockStore, o.Repository.StorageNamespace, address, req.Body, req.ContentLength, opts)
if err != nil {
@@ -309,7 +326,11 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) {
// write metadata
metadata := amzMetaAsMetadata(req)
contentType := req.Header.Get("Content-Type")
err = o.finishUpload(req, &blob.CreationDate, blob.Checksum, blob.PhysicalAddress, blob.Size, true, metadata, contentType)
err = o.finishUpload(req, &blob.CreationDate, blob.Checksum, blob.PhysicalAddress, blob.Size, true, metadata, contentType, allowOverwrite)
if errors.Is(err, graveler.ErrPreconditionFailed) {
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed))
return
}
if errors.Is(err, graveler.ErrWriteToProtectedBranch) {
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrWriteToProtectedBranch))
return
@@ -325,3 +346,14 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) {
o.SetHeader(w, "ETag", httputil.ETag(blob.Checksum))
w.WriteHeader(http.StatusOK)
}

func (o *PathOperation) checkIfAbsent(req *http.Request) (bool, error) {
headerValue := req.Header.Get(IfNoneMatchHeader)
if headerValue == "" {
return true, nil
}
if headerValue == "*" {
return false, nil
}
return false, gatewayErrors.ErrNotImplemented
}
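
As a supplement, a small hypothetical unit-test sketch (not included in this commit) that pins down the three cases checkIfAbsent distinguishes. It assumes it lives in the operations package next to putobject.go; the zero-value PathOperation is enough because the method does not use its receiver.

package operations

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestCheckIfAbsent(t *testing.T) {
	o := &PathOperation{}

	// No If-None-Match header: overwriting is allowed.
	req := httptest.NewRequest(http.MethodPut, "/repo/main/obj", nil)
	allow, err := o.checkIfAbsent(req)
	require.NoError(t, err)
	require.True(t, allow)

	// If-None-Match: "*" means write only if the object is absent.
	req.Header.Set("If-None-Match", "*")
	allow, err = o.checkIfAbsent(req)
	require.NoError(t, err)
	require.False(t, allow)

	// Any other value (e.g. a specific ETag) is rejected; the handlers map
	// this error to ErrNotImplemented in the response.
	req.Header.Set("If-None-Match", `"some-etag"`)
	_, err = o.checkIfAbsent(req)
	require.Error(t, err)
}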
