Skip to content

Commit

Permalink
Return response headers from SaveData call (#119)
Browse files Browse the repository at this point in the history
This will allow us to debug uploads better with our storage providers, passing them requestIDs etc
  • Loading branch information
mjh1 authored Jan 29, 2024
1 parent bcf005e commit 4dd4b5b
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 49 deletions.
19 changes: 12 additions & 7 deletions drivers/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ type FileProperties struct {
ContentType string
}

type SaveDataOutput struct {
URL string
UploaderResponseHeaders http.Header
}

var AvailableDrivers = []OSDriver{
&FSOS{},
&GsOS{},
Expand Down Expand Up @@ -133,7 +138,7 @@ const (
type OSSession interface {
OS() OSDriver

SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error)
SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error)
EndSession()

// Info in order to have this session used via RPC
Expand Down Expand Up @@ -309,19 +314,19 @@ func ParseOSURL(input string, useFullAPI bool) (OSDriver, error) {
}

// SaveRetried tries to SaveData specified number of times
func SaveRetried(ctx context.Context, sess OSSession, name string, data []byte, fields *FileProperties, retryCount int) (string, error) {
func SaveRetried(ctx context.Context, sess OSSession, name string, data []byte, fields *FileProperties, retryCount int) (*SaveDataOutput, error) {
if retryCount < 1 {
return "", fmt.Errorf("invalid retry count %d", retryCount)
return nil, fmt.Errorf("invalid retry count %d", retryCount)
}
var uri string
var out *SaveDataOutput
var err error
for i := 0; i < retryCount; i++ {
uri, err = sess.SaveData(ctx, name, bytes.NewReader(data), fields, 0)
out, err = sess.SaveData(ctx, name, bytes.NewReader(data), fields, 0)
if err == nil {
return uri, err
return out, err
}
}
return uri, err
return out, err
}

var httpc = &http.Client{
Expand Down
14 changes: 7 additions & 7 deletions drivers/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,35 +174,35 @@ func (ostore *FSSession) GetInfo() *OSInfo {
return nil
}

func (ostore *FSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (ostore *FSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
fullPath := ostore.getAbsoluteURI(name)
dir, name := path.Split(fullPath)
err := os.MkdirAll(dir, os.ModePerm)
if err != nil {
return "", err
return nil, err
}
file, err := os.Create(fullPath)
if err != nil {
return "", err
return nil, err
}
buf := make([]byte, 128*1024)
defer file.Close()
for {
select {
case <-ctx.Done():
return "", ctx.Err()
return nil, ctx.Err()
default:
read, err := data.Read(buf)
if err != nil && err != io.EOF {
return "", err
return nil, err
}
if read > 0 {
_, err = file.Write(buf[:read])
if err != nil {
return "", err
return nil, err
}
} else {
return fullPath, nil
return &SaveDataOutput{URL: fullPath}, nil
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions drivers/fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ func TestFsOS(t *testing.T) {
assert.NoError((err))
storage := NewFSDriver(u)
sess := storage.NewSession("driver-test").(*FSSession)
path, err := sess.SaveData(context.TODO(), "name1/1.ts", bytes.NewReader(rndData), nil, 0)
out, err := sess.SaveData(context.TODO(), "name1/1.ts", bytes.NewReader(rndData), nil, 0)
assert.NoError(err)
path := out.URL
defer os.Remove(path)
assert.Equal("/tmp/driver-test/name1/1.ts", path)
data := readFile(sess, "driver-test/name1/1.ts")
Expand All @@ -52,8 +53,9 @@ func TestFsOS(t *testing.T) {
// Test trim prefix when baseURI = nil
storage = NewFSDriver(nil)
sess = storage.NewSession("/tmp/").(*FSSession)
path, err = sess.SaveData(context.TODO(), "driver-test/name1/1.ts", bytes.NewReader(rndData), nil, 0)
out, err = sess.SaveData(context.TODO(), "driver-test/name1/1.ts", bytes.NewReader(rndData), nil, 0)
assert.NoError(err)
path = out.URL
defer os.Remove(path)
assert.Equal("/tmp/driver-test/name1/1.ts", path)
data = readFile(sess, path)
Expand Down
12 changes: 6 additions & 6 deletions drivers/gs.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ func (os *gsSession) DeleteFile(ctx context.Context, name string) error {
Delete(ctx)
}

func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
if os.useFullAPI {
if os.client == nil {
if err := os.createClient(); err != nil {
return "", err
return nil, err
}
}
keyname := os.key + "/" + name
Expand All @@ -201,19 +201,19 @@ func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader,
}
data, contentType, err := os.peekContentType(name, data)
if err != nil {
return "", err
return nil, err
}
wr.ContentType = contentType
_, err = io.Copy(wr, data)
err2 := wr.Close()
if err != nil {
return "", err
return nil, err
}
if err2 != nil {
return "", err2
return nil, err2
}
uri := os.getAbsURL(keyname)
return uri, err
return &SaveDataOutput{URL: uri}, err
}
return os.s3Session.SaveData(ctx, name, data, fields, timeout)
}
Expand Down
4 changes: 2 additions & 2 deletions drivers/ipfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ func (ostore *IpfsSession) DeleteFile(ctx context.Context, name string) error {
return ErrNotSupported
}

func (session *IpfsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (session *IpfsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
// concatenate filename with name argument to get full filename, both may be empty
fullPath := session.getAbsolutePath(name)
if fullPath == "" {
// pinata requires name to be set
fullPath = "data.bin"
}
cid, _, err := session.client.PinContent(ctx, fullPath, "", data)
return cid, err
return &SaveDataOutput{URL: cid}, err
}

func (session *IpfsSession) getAbsolutePath(name string) string {
Expand Down
3 changes: 2 additions & 1 deletion drivers/ipfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ func TestIpfsOS(t *testing.T) {
assert := assert.New(t)
storage := NewIpfsDriver(pinataKey, pinataSecret)
sess := storage.NewSession("").(*IpfsSession)
cid, err := sess.SaveData(context.TODO(), fileName, bytes.NewReader(rndData), nil, 0)
out, err := sess.SaveData(context.TODO(), fileName, bytes.NewReader(rndData), nil, 0)
assert.NoError(err)
cid := out.URL
// first, list file through API
files, err := sess.ListFiles(context.TODO(), cid, "")
assert.NoError(err)
Expand Down
8 changes: 4 additions & 4 deletions drivers/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,24 +222,24 @@ func (ostore *MemoryOS) Description() string {
return "Memory driver."
}

func (ostore *MemorySession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (ostore *MemorySession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
path, file := path.Split(ostore.getAbsolutePath(name))

ostore.dLock.Lock()
defer ostore.dLock.Unlock()

if ostore.ended {
return "", fmt.Errorf("Session ended")
return nil, fmt.Errorf("Session ended")
}

bytes, err := ioutil.ReadAll(data)
if err != nil {
return "", err
return nil, err
}
dc := ostore.getCacheForStream(path)
dc.Insert(file, bytes)

return ostore.getAbsoluteURI(name), nil
return &SaveDataOutput{URL: ostore.getAbsoluteURI(name)}, nil
}

func (ostore *MemorySession) getCacheForStream(streamID string) *dataCache {
Expand Down
9 changes: 6 additions & 3 deletions drivers/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ func TestLocalOS(t *testing.T) {

os := NewMemoryDriver(u)
sess := os.NewSession(("sesspath")).(*MemorySession)
path, err := sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0)
out, err := sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0)
require.NoError(t, err)
path := out.URL
require.Equal(t, "fake.com/url/stream/sesspath/name1/1.ts", path)

data := sess.GetData("sesspath/name1/1.ts")
Expand All @@ -37,8 +38,9 @@ func TestLocalOS(t *testing.T) {
data = sess.GetData("sesspath/name1/1.ts")
require.Equal(t, tempData2, string(data))

path, err = sess.SaveData(context.TODO(), "name1/2.ts", strings.NewReader(tempData3), nil, 0)
out, err = sess.SaveData(context.TODO(), "name1/2.ts", strings.NewReader(tempData3), nil, 0)
require.NoError(t, err)
path = out.URL

data = sess.GetData("sesspath/name1/2.ts")
require.Equal(t, tempData3, string(data))
Expand All @@ -56,8 +58,9 @@ func TestLocalOS(t *testing.T) {
// Test trim prefix when baseURI = nil
os = NewMemoryDriver(nil)
sess = os.NewSession("sesspath").(*MemorySession)
path, err = sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0)
out, err = sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0)
require.NoError(t, err)
path = out.URL
require.Equal(t, "/stream/sesspath/name1/1.ts", path)

data = sess.GetData(path)
Expand Down
21 changes: 13 additions & 8 deletions drivers/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/aws/aws-sdk-go/aws/request"
"io"
"mime/multipart"
"net/http"
Expand Down Expand Up @@ -370,7 +371,7 @@ func (os *s3Session) ReadDataRange(ctx context.Context, name, byteRange string)
return res, nil
}

func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
bucket := aws.String(os.bucket)
keyname := aws.String(path.Join(os.key, name))
var metadata map[string]*string
Expand All @@ -382,15 +383,17 @@ func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reade
}
data, contentType, err := os.peekContentType(name, data)
if err != nil {
return "", err
return nil, err
}
if fields != nil && fields.ContentType != "" {
contentType = fields.ContentType
}

respHeaders := http.Header{}
uploader := s3manager.NewUploader(os.s3sess, func(u *s3manager.Uploader) {
u.Concurrency = uploaderConcurrency
u.PartSize = uploaderPartSize
u.RequestOptions = append(u.RequestOptions, request.WithGetResponseHeaders(&respHeaders))
})
params := &s3manager.UploadInput{
Bucket: bucket,
Expand All @@ -409,11 +412,13 @@ func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reade
_, err = uploader.UploadWithContext(ctx, params)
cancel()
if err != nil {
return "", err
return nil, err
}

url := os.getAbsURL(*keyname)
return url, nil
return &SaveDataOutput{
URL: os.getAbsURL(*keyname),
UploaderResponseHeaders: respHeaders,
}, nil
}

func (os *s3Session) DeleteFile(ctx context.Context, name string) error {
Expand All @@ -431,19 +436,19 @@ func (os *s3Session) DeleteFile(ctx context.Context, name string) error {
return err
}

func (os *s3Session) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (os *s3Session) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
if os.s3svc != nil {
return os.saveDataPut(ctx, name, data, fields, timeout)
}
_ = path.Join(os.host, os.key, name)
path, err := os.postData(ctx, name, data, fields, timeout)
if err != nil {
// handle error
return "", err
return nil, err
}

url := os.getAbsURL(path)
return url, nil
return &SaveDataOutput{URL: url}, nil
}

func (os *s3Session) getAbsURL(path string) string {
Expand Down
3 changes: 2 additions & 1 deletion drivers/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ func S3UploadTest(require *require.Assertions, fullUriStr, saveName string) {
require.NoError(err)

session := os.NewSession("")
outUriStr, err := session.SaveData(context.Background(), saveName, bytes.NewReader(testData), nil, 10*time.Second)
out, err := session.SaveData(context.Background(), saveName, bytes.NewReader(testData), nil, 10*time.Second)
require.NoError(err)
outUriStr := out.URL

var data *FileInfoReader
// for specific key session, saveName is empty, otherwise, it's the key
Expand Down
4 changes: 2 additions & 2 deletions drivers/session_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ func NewMockOSSession() *MockOSSession {
}
}

func (s *MockOSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (s *MockOSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
args := s.Called(name, data, fields, timeout)
if s.waitForCh {
s.back <- struct{}{}
<-s.waitCh
s.waitForCh = false
}
return args.String(0), args.Error(1)
return &SaveDataOutput{URL: args.String(0)}, args.Error(1)
}

func (s *MockOSSession) EndSession() {
Expand Down
12 changes: 6 additions & 6 deletions drivers/w3s.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (session *W3sSession) DeleteFile(ctx context.Context, name string) error {
return ErrNotSupported
}

func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
if timeout <= 0 {
timeout = w3SDefaultSaveTimeout
}
Expand All @@ -145,27 +145,27 @@ func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Re

filePath, err := toFile(data)
if err != nil {
return "", err
return nil, err
}
defer deleteFile(filePath)

carPath, fileCid, err := ipfsCarPack(ctx, filePath)
if err != nil {
return "", err
return nil, err
}
defer deleteFile(carPath)

carCid, err := w3StoreCar(ctx, session.os.ucanProof, carPath)
if err != nil {
return "", err
return nil, err
}

rCar := session.os.getRootCar()
if err = rCar.addFile(ctx, session.os.dirPath, name, fileCid, carCid); err != nil {
return "", err
return nil, err
}

return fileCid, nil
return &SaveDataOutput{URL: fileCid}, nil
}

func (rc *rootCar) addFile(ctx context.Context, dirPath, filename, fileCid, carCid string) error {
Expand Down

0 comments on commit 4dd4b5b

Please sign in to comment.