Skip to content

Commit

Permalink
Listening to when the client cancels the connection to stop operation…
Browse files Browse the repository at this point in the history
…s on the store
  • Loading branch information
ItalyPaleAle committed May 27, 2020
1 parent 9e8c647 commit b44064c
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 26 deletions.
4 changes: 3 additions & 1 deletion cmd/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
package cmd

import (
"context"
"fmt"
"path/filepath"
"strings"
Expand Down Expand Up @@ -107,6 +108,7 @@ You must specify a destination, which is a folder inside the repository where yo
}

// Iterate through the args and add them all
ctx := context.Background()
res := make(chan repository.PathResultMessage)
go func() {
var err error
Expand All @@ -125,7 +127,7 @@ You must specify a destination, which is a folder inside the repository where yo
folder := filepath.Dir(expanded)
target := filepath.Base(expanded)

repo.AddPath(folder, target, flagDestination, res)
repo.AddPath(ctx, folder, target, flagDestination, res)
}

close(res)
Expand Down
18 changes: 13 additions & 5 deletions fs/fs-azure-storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (f *AzureStorage) GetInfoFile() (info *infofile.InfoFile, err error) {
blockBlobURL := azblob.NewBlockBlobURL(*u, f.storagePipeline)

// Download the file
resp, err := blockBlobURL.Download(context.TODO(), 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false)
resp, err := blockBlobURL.Download(context.Background(), 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false)
if err != nil {
if stgErr, ok := err.(azblob.StorageError); !ok {
err = fmt.Errorf("network error while downloading the file: %s", err.Error())
Expand Down Expand Up @@ -158,7 +158,7 @@ func (f *AzureStorage) SetInfoFile(info *infofile.InfoFile) (err error) {
blockBlobURL := azblob.NewBlockBlobURL(*u, f.storagePipeline)

// Upload
_, err = azblob.UploadBufferToBlockBlob(context.TODO(), data, blockBlobURL, azblob.UploadToBlockBlobOptions{})
_, err = azblob.UploadBufferToBlockBlob(context.Background(), data, blockBlobURL, azblob.UploadToBlockBlobOptions{})
if err != nil {
if stgErr, ok := err.(azblob.StorageError); !ok {
return fmt.Errorf("network error while uploading the file: %s", err.Error())
Expand All @@ -171,6 +171,10 @@ func (f *AzureStorage) SetInfoFile(info *infofile.InfoFile) (err error) {
}

func (f *AzureStorage) Get(name string, out io.Writer, metadataCb crypto.MetadataCb) (found bool, tag interface{}, err error) {
return f.GetWithContext(context.Background(), name, out, metadataCb)
}

func (f *AzureStorage) GetWithContext(ctx context.Context, name string, out io.Writer, metadataCb crypto.MetadataCb) (found bool, tag interface{}, err error) {
if name == "" {
err = errors.New("name is empty")
return
Expand All @@ -192,7 +196,7 @@ func (f *AzureStorage) Get(name string, out io.Writer, metadataCb crypto.Metadat
blockBlobURL := azblob.NewBlockBlobURL(*u, f.storagePipeline)

// Download the file
resp, err := blockBlobURL.Download(context.TODO(), 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false)
resp, err := blockBlobURL.Download(ctx, 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false)
if err != nil {
if stgErr, ok := err.(azblob.StorageError); !ok {
err = fmt.Errorf("network error while downloading the file: %s", err.Error())
Expand Down Expand Up @@ -232,6 +236,10 @@ func (f *AzureStorage) Get(name string, out io.Writer, metadataCb crypto.Metadat
}

func (f *AzureStorage) Set(name string, in io.Reader, tag interface{}, metadata *crypto.Metadata) (tagOut interface{}, err error) {
return f.SetWithContext(context.Background(), name, in, tag, metadata)
}

func (f *AzureStorage) SetWithContext(ctx context.Context, name string, in io.Reader, tag interface{}, metadata *crypto.Metadata) (tagOut interface{}, err error) {
if name == "" {
err = errors.New("name is empty")
return
Expand Down Expand Up @@ -282,7 +290,7 @@ func (f *AzureStorage) Set(name string, in io.Reader, tag interface{}, metadata
}
}

resp, err := azblob.UploadStreamToBlockBlob(context.TODO(), pr, blockBlobURL, azblob.UploadStreamToBlockBlobOptions{
resp, err := azblob.UploadStreamToBlockBlob(ctx, pr, blockBlobURL, azblob.UploadStreamToBlockBlobOptions{
BufferSize: 3 * 1024 * 1024,
MaxBuffers: 2,
AccessConditions: accessConditions,
Expand Down Expand Up @@ -334,6 +342,6 @@ func (f *AzureStorage) Delete(name string, tag interface{}) (err error) {
}

// Delete the blob
_, err = blockBlobURL.Delete(context.TODO(), azblob.DeleteSnapshotsOptionInclude, accessConditions)
_, err = blockBlobURL.Delete(context.Background(), azblob.DeleteSnapshotsOptionInclude, accessConditions)
return
}
13 changes: 11 additions & 2 deletions fs/fs-local.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
package fs

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -138,6 +139,10 @@ func (f *Local) SetInfoFile(info *infofile.InfoFile) (err error) {
}

func (f *Local) Get(name string, out io.Writer, metadataCb crypto.MetadataCb) (found bool, tag interface{}, err error) {
return f.GetWithContext(context.Background(), name, out, metadataCb)
}

func (f *Local) GetWithContext(ctx context.Context, name string, out io.Writer, metadataCb crypto.MetadataCb) (found bool, tag interface{}, err error) {
if name == "" {
err = errors.New("name is empty")
return
Expand Down Expand Up @@ -172,7 +177,7 @@ func (f *Local) Get(name string, out io.Writer, metadataCb crypto.MetadataCb) (f
}

// Decrypt the data
err = crypto.DecryptFile(out, file, f.masterKey, metadataCb)
err = crypto.DecryptFile(utils.WriterFuncWithContext(ctx, out), file, f.masterKey, metadataCb)
if err != nil {
return
}
Expand All @@ -181,6 +186,10 @@ func (f *Local) Get(name string, out io.Writer, metadataCb crypto.MetadataCb) (f
}

func (f *Local) Set(name string, in io.Reader, tag interface{}, metadata *crypto.Metadata) (tagOut interface{}, err error) {
return f.SetWithContext(context.Background(), name, in, tag, metadata)
}

func (f *Local) SetWithContext(ctx context.Context, name string, in io.Reader, tag interface{}, metadata *crypto.Metadata) (tagOut interface{}, err error) {
if name == "" {
err = errors.New("name is empty")
return
Expand All @@ -205,7 +214,7 @@ func (f *Local) Set(name string, in io.Reader, tag interface{}, metadata *crypto
}

// Encrypt the data and write it to file
err = crypto.EncryptFile(file, in, f.masterKey, metadata)
err = crypto.EncryptFile(file, utils.ReaderFuncWithContext(ctx, in), f.masterKey, metadata)
if err != nil {
return nil, err
}
Expand Down
13 changes: 11 additions & 2 deletions fs/fs-s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package fs

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -143,6 +144,10 @@ func (f *S3) SetInfoFile(info *infofile.InfoFile) (err error) {
}

func (f *S3) Get(name string, out io.Writer, metadataCb crypto.MetadataCb) (found bool, tag interface{}, err error) {
return f.GetWithContext(context.Background(), name, out, metadataCb)
}

func (f *S3) GetWithContext(ctx context.Context, name string, out io.Writer, metadataCb crypto.MetadataCb) (found bool, tag interface{}, err error) {
if name == "" {
err = errors.New("name is empty")
return
Expand All @@ -157,7 +162,7 @@ func (f *S3) Get(name string, out io.Writer, metadataCb crypto.MetadataCb) (foun
found = true

// Request the file from S3
obj, err := f.client.GetObject(f.bucketName, folder+name, minio.GetObjectOptions{})
obj, err := f.client.GetObjectWithContext(ctx, f.bucketName, folder+name, minio.GetObjectOptions{})
if err != nil {
return
}
Expand All @@ -179,6 +184,10 @@ func (f *S3) Get(name string, out io.Writer, metadataCb crypto.MetadataCb) (foun
}

func (f *S3) Set(name string, in io.Reader, tag interface{}, metadata *crypto.Metadata) (tagOut interface{}, err error) {
return f.SetWithContext(context.Background(), name, in, tag, metadata)
}

func (f *S3) SetWithContext(ctx context.Context, name string, in io.Reader, tag interface{}, metadata *crypto.Metadata) (tagOut interface{}, err error) {
if name == "" {
err = errors.New("name is empty")
return nil, err
Expand All @@ -199,7 +208,7 @@ func (f *S3) Set(name string, in io.Reader, tag interface{}, metadata *crypto.Me
}
pw.Close()
}()
_, err = f.client.PutObject(f.bucketName, folder+name, pr, -1, minio.PutObjectOptions{})
_, err = f.client.PutObjectWithContext(ctx, f.bucketName, folder+name, pr, -1, minio.PutObjectOptions{})
if err != nil {
return nil, err
}
Expand Down
7 changes: 7 additions & 0 deletions fs/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
package fs

import (
"context"
"fmt"
"io"
"strings"
Expand Down Expand Up @@ -75,10 +76,16 @@ type Fs interface {
// It also returns a tag (which might be empty) that should be passed to the Set method if you want to subsequentially update the contents of the file
Get(name string, out io.Writer, metadataCb crypto.MetadataCb) (found bool, tag interface{}, err error)

// GetWithContext is like Get, but accepts a custom context
GetWithContext(ctx context.Context, name string, out io.Writer, metadataCb crypto.MetadataCb) (found bool, tag interface{}, err error)

// Set writes a stream to the file in the filesystem
// If you pass a tag, the implementation might use that to ensure that the file on the filesystem hasn't been changed since it was read (optional)
Set(name string, in io.Reader, tag interface{}, metadata *crypto.Metadata) (tagOut interface{}, err error)

// SetWithContext is like Set, but accepts a custom context
SetWithContext(ctx context.Context, name string, in io.Reader, tag interface{}, metadata *crypto.Metadata) (tagOut interface{}, err error)

// Delete a file from the filesystem
// If you pass a tag, the implementation might use that to ensure that the file on the filesystem hasn't been changed since it was read (optional)
Delete(name string, tag interface{}) (err error)
Expand Down
15 changes: 8 additions & 7 deletions repository/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
package repository

import (
"context"
"errors"
"io"
"os"
Expand All @@ -32,7 +33,7 @@ import (
)

// AddStream adds a document to the repository by reading it from a stream
func (repo *Repository) AddStream(in io.ReadCloser, filename, destinationFolder, mimeType string, size int64) (int, error) {
func (repo *Repository) AddStream(ctx context.Context, in io.ReadCloser, filename, destinationFolder, mimeType string, size int64) (int, error) {
// Generate a file id
fileId, err := uuid.NewV4()
if err != nil {
Expand Down Expand Up @@ -61,7 +62,7 @@ func (repo *Repository) AddStream(in io.ReadCloser, filename, destinationFolder,
ContentType: mimeType,
Size: size,
}
_, err = repo.Store.Set(fileId.String(), in, nil, metadata)
_, err = repo.Store.SetWithContext(ctx, fileId.String(), in, nil, metadata)
if err != nil {
return RepositoryStatusInternalError, err
}
Expand All @@ -77,7 +78,7 @@ func (repo *Repository) AddStream(in io.ReadCloser, filename, destinationFolder,

// AddFile adds a file to the repository
// This accepts any regular file, and it does not ignore any file
func (repo *Repository) AddFile(folder, target, destinationFolder string) (int, error) {
func (repo *Repository) AddFile(ctx context.Context, folder, target, destinationFolder string) (int, error) {
path := filepath.Join(folder, target)

// Check if target exists and it's a regular file
Expand Down Expand Up @@ -111,11 +112,11 @@ func (repo *Repository) AddFile(folder, target, destinationFolder string) (int,
size := stat.Size()

// Add the file's stream
return repo.AddStream(in, target, destinationFolder, mimeType, size)
return repo.AddStream(ctx, in, target, destinationFolder, mimeType, size)
}

// AddPath adds a path (a file or a folder, recursively) and reports each element added in the res channel
func (repo *Repository) AddPath(folder, target, destinationFolder string, res chan<- PathResultMessage) {
func (repo *Repository) AddPath(ctx context.Context, folder, target, destinationFolder string, res chan<- PathResultMessage) {
path := filepath.Join(folder, target)

// Check if target exists
Expand Down Expand Up @@ -159,7 +160,7 @@ func (repo *Repository) AddPath(folder, target, destinationFolder string, res ch

// For files, add that
if isFile {
status, err := repo.AddFile(folder, target, destinationFolder)
status, err := repo.AddFile(ctx, folder, target, destinationFolder)
res <- PathResultMessage{
Path: destinationFolder + target,
Status: status,
Expand Down Expand Up @@ -190,7 +191,7 @@ func (repo *Repository) AddPath(folder, target, destinationFolder string, res ch
}
for _, el := range list {
// Recursion
repo.AddPath(path, el.Name(), destinationFolder+target+"/", res)
repo.AddPath(ctx, path, el.Name(), destinationFolder+target+"/", res)
}
}

Expand Down
15 changes: 9 additions & 6 deletions server/api-post-tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
package server

import (
"context"
"errors"
"mime/multipart"
"net/http"
Expand All @@ -35,6 +36,8 @@ import (
// - A file transmitted in the request body (in the "file" field)
// - The path to a file or folder in the local filesystem (in the "localpath" field(s))
func (s *Server) PostTreeHandler(c *gin.Context) {
ctx := c.Request.Context()

// Get the path (can be empty if targeting the root)
path := c.Param("path")
// Ensure that the path starts with / and ends with "/"
Expand All @@ -54,9 +57,9 @@ func (s *Server) PostTreeHandler(c *gin.Context) {
localPaths := c.PostFormArray("localpath")
uploadFile, _ := c.FormFile("file")
if len(localPaths) > 0 && uploadFile == nil {
s.addLocalPath(localPaths, path, res)
s.addLocalPath(ctx, localPaths, path, res)
} else if len(localPaths) == 0 && uploadFile != nil {
s.addUploadedFile(uploadFile, path, res)
s.addUploadedFile(ctx, uploadFile, path, res)
} else {
c.AbortWithError(http.StatusBadRequest, errors.New("need to specify one and only one of 'file' or 'localpath' form fields"))
return
Expand Down Expand Up @@ -90,7 +93,7 @@ func (s *Server) PostTreeHandler(c *gin.Context) {
}

// Adds files from the local filesystem, passing the path
func (s *Server) addLocalPath(paths []string, destination string, res chan<- repository.PathResultMessage) {
func (s *Server) addLocalPath(ctx context.Context, paths []string, destination string, res chan<- repository.PathResultMessage) {
// Iterate through the paths and add them all
var err error
var expanded string
Expand All @@ -108,12 +111,12 @@ func (s *Server) addLocalPath(paths []string, destination string, res chan<- rep
folder := filepath.Dir(expanded)
target := filepath.Base(expanded)

s.Repo.AddPath(folder, target, destination, res)
s.Repo.AddPath(ctx, folder, target, destination, res)
}
}

// Add a file by a stream
func (s *Server) addUploadedFile(uploadFile *multipart.FileHeader, destination string, res chan<- repository.PathResultMessage) {
func (s *Server) addUploadedFile(ctx context.Context, uploadFile *multipart.FileHeader, destination string, res chan<- repository.PathResultMessage) {
// Filename
filename := filepath.Base(uploadFile.Filename)
if filename == "" || filename == ".." || filename == "." || filename == "/" {
Expand Down Expand Up @@ -150,7 +153,7 @@ func (s *Server) addUploadedFile(uploadFile *multipart.FileHeader, destination s
}

// Add the file
result, err := s.Repo.AddStream(in, filename, destination, mime, size)
result, err := s.Repo.AddStream(ctx, in, filename, destination, mime, size)
res <- repository.PathResultMessage{
Path: destination + filename,
Status: result,
Expand Down
4 changes: 3 additions & 1 deletion server/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (

// FileHandler is the handler for GET /file/:fileId, which returns a (decrypted) file
func (s *Server) FileHandler(c *gin.Context) {
ctx := c.Request.Context()

// Get the fileId
fileId := c.Param("fileId")
if fileId == "" {
Expand All @@ -52,7 +54,7 @@ func (s *Server) FileHandler(c *gin.Context) {
}

// Load and decrypt the file, then pipe it to the response writer
found, _, err := s.Store.Get(fileId, c.Writer, func(metadata *crypto.Metadata) {
found, _, err := s.Store.GetWithContext(ctx, fileId, c.Writer, func(metadata *crypto.Metadata) {
// Send headers before the data is sent
if metadata.ContentType != "" {
c.Header("Content-Type", metadata.ContentType)
Expand Down
7 changes: 5 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"os"
"os/signal"
"syscall"
"time"

"github.com/ItalyPaleAle/prvt/infofile"

Expand Down Expand Up @@ -89,11 +90,13 @@ func (s *Server) Start(address, port string) error {
signal.Notify(s, os.Interrupt, syscall.SIGTERM)
<-s

// We received an interrupt signal, shut down.
if err := server.Shutdown(context.Background()); err != nil {
// We received an interrupt signal, shut down
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err := server.Shutdown(ctx); err != nil {
// Error from closing listeners, or context timeout:
fmt.Printf("HTTP server shutdown error: %v\n", err)
}
cancel()
close(idleConnsClosed)
}()

Expand Down
Loading

0 comments on commit b44064c

Please sign in to comment.