diff --git a/internal/config/config.go b/internal/config/config.go index 0861c2f7f..fadbd81e7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -287,13 +287,19 @@ func registerClientFlags(fs *flag.FlagSet) { fs.Int( ClientGRPCMaxMessageReceiveSizeKey, DefMaxMessageRecieveSize, - "Updates the client grpc setting MaxRecvMsgSize with the specific value in MB.", + "Updates the client grpc setting MaxRecvMsgSize with the specific value in bytes.", ) fs.Int( ClientGRPCMaxMessageSendSizeKey, DefMaxMessageSendSize, - "Updates the client grpc setting MaxSendMsgSize with the specific value in MB.", + "Updates the client grpc setting MaxSendMsgSize with the specific value in bytes.", + ) + + fs.Uint32( + ClientGRPCFileChunkSizeKey, + DefFileChunkSize, + "File chunk size in bytes.", ) } @@ -611,6 +617,7 @@ func resolveClient() *Client { MaxMessageSize: viperInstance.GetInt(ClientGRPCMaxMessageSizeKey), MaxMessageReceiveSize: viperInstance.GetInt(ClientGRPCMaxMessageReceiveSizeKey), MaxMessageSendSize: viperInstance.GetInt(ClientGRPCMaxMessageSendSizeKey), + FileChunkSize: viperInstance.GetUint32(ClientGRPCFileChunkSizeKey), }, Backoff: &BackOff{ InitialInterval: viperInstance.GetDuration(ClientBackoffInitialIntervalKey), diff --git a/internal/config/config_test.go b/internal/config/config_test.go index cafa6a5aa..f93703774 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -100,6 +100,7 @@ func checkDefaultsClientValues(t *testing.T, viperInstance *viper.Viper) { assert.Equal(t, DefMaxMessageSize, viperInstance.GetInt(ClientGRPCMaxMessageSizeKey)) assert.Equal(t, DefMaxMessageRecieveSize, viperInstance.GetInt(ClientGRPCMaxMessageReceiveSizeKey)) assert.Equal(t, DefMaxMessageSendSize, viperInstance.GetInt(ClientGRPCMaxMessageSendSizeKey)) + assert.Equal(t, DefFileChunkSize, viperInstance.GetUint32(ClientGRPCFileChunkSizeKey)) assert.Equal(t, make(map[string]string), viperInstance.GetStringMapString(LabelsRootKey)) } @@ -746,6 +747,7 @@ func createConfig() *Config { MaxMessageSize: 1048575, MaxMessageReceiveSize: 1048575, MaxMessageSendSize: 1048575, + FileChunkSize: 48575, }, Backoff: &BackOff{ InitialInterval: 200 * time.Millisecond, diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 4c51eb006..fb8bb943c 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -5,7 +5,6 @@ package config import ( - "math" "time" pkg "github.com/nginx/agent/v3/pkg/config" @@ -28,9 +27,10 @@ const ( DefCommandTLServerNameKey = "" // Client GRPC Settings - DefMaxMessageSize = 0 // 0 = unset - DefMaxMessageRecieveSize = 4194304 // default 4 MB - DefMaxMessageSendSize = math.MaxInt32 + DefMaxMessageSize = 0 // 0 = unset + DefMaxMessageRecieveSize = 4194304 // default 4 MB + DefMaxMessageSendSize = 4194304 // default 4 MB + DefFileChunkSize uint32 = 2097152 // 2MB // Client HTTP Settings DefHTTPTimeout = 10 * time.Second diff --git a/internal/config/flags.go b/internal/config/flags.go index c98268990..8d4f0fe21 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -36,6 +36,7 @@ var ( ClientGRPCMaxMessageSendSizeKey = pre(ClientRootKey) + "grpc_max_message_send_size" ClientGRPCMaxMessageReceiveSizeKey = pre(ClientRootKey) + "grpc_max_message_receive_size" ClientGRPCMaxMessageSizeKey = pre(ClientRootKey) + "grpc_max_message_size" + ClientGRPCFileChunkSizeKey = pre(ClientRootKey) + "grpc_file_chunk_size" ClientBackoffInitialIntervalKey = pre(ClientRootKey) + "backoff_initial_interval" 
ClientBackoffMaxIntervalKey = pre(ClientRootKey) + "backoff_max_interval" diff --git a/internal/config/testdata/nginx-agent.conf b/internal/config/testdata/nginx-agent.conf index 9f9a9d3d9..d631ac337 100644 --- a/internal/config/testdata/nginx-agent.conf +++ b/internal/config/testdata/nginx-agent.conf @@ -35,6 +35,7 @@ client: max_message_size: 1048575 max_message_receive_size: 1048575 max_message_send_size: 1048575 + file_chunk_size: 48575 backoff: initial_interval: 200ms max_interval: 10s diff --git a/internal/config/types.go b/internal/config/types.go index 3afb29756..0a777deed 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -80,12 +80,11 @@ type ( } GRPC struct { - KeepAlive *KeepAlive `yaml:"-" mapstructure:"target"` - // if MaxMessageSize is size set then we use that value, - // otherwise MaxMessageRecieveSize and MaxMessageSendSize for individual settings - MaxMessageSize int `yaml:"-" mapstructure:"max_message_size"` - MaxMessageReceiveSize int `yaml:"-" mapstructure:"max_message_receive_size"` - MaxMessageSendSize int `yaml:"-" mapstructure:"max_message_send_size"` + KeepAlive *KeepAlive `yaml:"-" mapstructure:"target"` + MaxMessageSize int `yaml:"-" mapstructure:"max_message_size"` + MaxMessageReceiveSize int `yaml:"-" mapstructure:"max_message_receive_size"` + MaxMessageSendSize int `yaml:"-" mapstructure:"max_message_send_size"` + FileChunkSize uint32 `yaml:"file_chunk_size" mapstructure:"file_chunk_size"` } KeepAlive struct { diff --git a/internal/file/file_manager_service.go b/internal/file/file_manager_service.go index 2aaf02128..33db307e8 100644 --- a/internal/file/file_manager_service.go +++ b/internal/file/file_manager_service.go @@ -6,12 +6,15 @@ package file import ( + "bufio" "context" "encoding/json" "errors" "fmt" + "io" "log/slog" "maps" + "math" "os" "slices" "sync" @@ -30,6 +33,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" backoffHelpers "github.com/nginx/agent/v3/internal/backoff" + grpc2 "google.golang.org/grpc" ) //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6@v6.8.1 -generate @@ -52,6 +56,7 @@ var ( type ( fileOperator interface { Write(ctx context.Context, fileContent []byte, file *mpi.FileMeta) error + CreateFileDirectories(ctx context.Context, fileMeta *mpi.FileMeta, filePermission os.FileMode) error } fileManagerServiceInterface interface { @@ -226,23 +231,40 @@ func (fms *FileManagerService) UpdateFile( fileToUpdate *mpi.File, ) error { slog.InfoContext(ctx, "Updating file", "instance_id", instanceID, "file_name", fileToUpdate.GetFileMeta().GetName()) + + slog.DebugContext(ctx, "Checking file size", + "file_size", fileToUpdate.GetFileMeta().GetSize(), + "max_message_send_size", int64(fms.agentConfig.Client.Grpc.MaxMessageSendSize), + ) + + if fileToUpdate.GetFileMeta().GetSize() <= int64(fms.agentConfig.Client.Grpc.MaxMessageSendSize) { + return fms.sendUpdateFileRequest(ctx, fileToUpdate) + } + + return fms.sendUpdateFileStream(ctx, fileToUpdate, fms.agentConfig.Client.Grpc.FileChunkSize) +} + +func (fms *FileManagerService) sendUpdateFileRequest( + ctx context.Context, + fileToUpdate *mpi.File, +) error { + messageMeta := &mpi.MessageMeta{ + MessageId: id.GenerateMessageID(), + CorrelationId: logger.GetCorrelationID(ctx), + Timestamp: timestamppb.Now(), + } + contents, err := os.ReadFile(fileToUpdate.GetFileMeta().GetName()) if err != nil { return err } - correlationID := logger.GetCorrelationID(ctx) - request := &mpi.UpdateFileRequest{ File: fileToUpdate, Contents: &mpi.FileContents{ Contents: 
contents, }, - MessageMeta: &mpi.MessageMeta{ - MessageId: id.GenerateMessageID(), - CorrelationId: correlationID, - Timestamp: timestamppb.Now(), - }, + MessageMeta: messageMeta, } backOffCtx, backoffCancel := context.WithTimeout(ctx, fms.agentConfig.Client.Backoff.MaxElapsedTime) @@ -272,15 +294,215 @@ func (fms *FileManagerService) UpdateFile( return response, nil } - response, err := backoff.RetryWithData(sendUpdateFile, backoffHelpers.Context(backOffCtx, - fms.agentConfig.Client.Backoff)) + response, err := backoff.RetryWithData( + sendUpdateFile, + backoffHelpers.Context(backOffCtx, fms.agentConfig.Client.Backoff), + ) if err != nil { return err } slog.DebugContext(ctx, "UpdateFile response", "response", response) - return err + return nil +} + +func (fms *FileManagerService) sendUpdateFileStream( + ctx context.Context, + fileToUpdate *mpi.File, + chunkSize uint32, +) error { + if chunkSize == 0 { + return fmt.Errorf("file chunk size must be greater than zero") + } + + updateFileStreamClient, err := fms.fileServiceClient.UpdateFileStream(ctx) + if err != nil { + return err + } + + err = fms.sendUpdateFileStreamHeader(ctx, fileToUpdate, chunkSize, updateFileStreamClient) + if err != nil { + return err + } + + return fms.sendFileUpdateStreamChunks(ctx, fileToUpdate, chunkSize, updateFileStreamClient) +} + +func (fms *FileManagerService) sendUpdateFileStreamHeader( + ctx context.Context, + fileToUpdate *mpi.File, + chunkSize uint32, + updateFileStreamClient grpc2.ClientStreamingClient[mpi.FileDataChunk, mpi.UpdateFileResponse], +) error { + messageMeta := &mpi.MessageMeta{ + MessageId: id.GenerateMessageID(), + CorrelationId: logger.GetCorrelationID(ctx), + Timestamp: timestamppb.Now(), + } + + numberOfChunks := uint32(math.Ceil(float64(fileToUpdate.GetFileMeta().GetSize()) / float64(chunkSize))) + + header := mpi.FileDataChunk_Header{ + Header: &mpi.FileDataChunkHeader{ + FileMeta: fileToUpdate.GetFileMeta(), + Chunks: numberOfChunks, + ChunkSize: chunkSize, + }, + } + + backOffCtx, backoffCancel := context.WithTimeout(ctx, fms.agentConfig.Client.Backoff.MaxElapsedTime) + defer backoffCancel() + + sendUpdateFileHeader := func() error { + slog.DebugContext(ctx, "Sending update file stream header", "header", header) + if fms.fileServiceClient == nil { + return errors.New("file service client is not initialized") + } + + if !fms.isConnected.Load() { + return errors.New("CreateConnection rpc has not being called yet") + } + + err := updateFileStreamClient.Send( + &mpi.FileDataChunk{ + Meta: messageMeta, + Chunk: &header, + }, + ) + + validatedError := grpc.ValidateGrpcError(err) + + if validatedError != nil { + slog.ErrorContext(ctx, "Failed to send update file stream header", "error", validatedError) + + return validatedError + } + + return nil + } + + return backoff.Retry(sendUpdateFileHeader, backoffHelpers.Context(backOffCtx, fms.agentConfig.Client.Backoff)) +} + +func (fms *FileManagerService) sendFileUpdateStreamChunks( + ctx context.Context, + fileToUpdate *mpi.File, + chunkSize uint32, + updateFileStreamClient grpc2.ClientStreamingClient[mpi.FileDataChunk, mpi.UpdateFileResponse], +) error { + f, err := os.Open(fileToUpdate.GetFileMeta().GetName()) + defer func() { + closeError := f.Close() + if closeError != nil { + slog.WarnContext( + ctx, "Failed to close file", + "file", fileToUpdate.GetFileMeta().GetName(), + "error", closeError, + ) + } + }() + if err != nil { + return err + } + + var chunkID uint32 + + reader := bufio.NewReader(f) + for { + chunk, readChunkError := 
fms.readChunk(ctx, chunkSize, reader, chunkID) + if readChunkError != nil { + return readChunkError + } + if chunk.Content == nil { + break + } + + sendError := fms.sendFileUpdateStreamChunk(ctx, chunk, updateFileStreamClient) + if sendError != nil { + return sendError + } + + chunkID++ + } + + return nil +} + +func (fms *FileManagerService) readChunk( + ctx context.Context, + chunkSize uint32, + reader *bufio.Reader, + chunkID uint32, +) (mpi.FileDataChunk_Content, error) { + buf := make([]byte, chunkSize) + n, err := reader.Read(buf) + buf = buf[:n] + if err != nil { + if err != io.EOF { + return mpi.FileDataChunk_Content{}, fmt.Errorf("failed to read chunk: %w", err) + } + + slog.DebugContext(ctx, "No more data to read from file") + + return mpi.FileDataChunk_Content{}, nil + } + + slog.DebugContext(ctx, "Read file chunk", "chunk_id", chunkID, "chunk_size", len(buf)) + + chunk := mpi.FileDataChunk_Content{ + Content: &mpi.FileDataChunkContent{ + ChunkId: chunkID, + Data: buf, + }, + } + + return chunk, err +} + +func (fms *FileManagerService) sendFileUpdateStreamChunk( + ctx context.Context, + chunk mpi.FileDataChunk_Content, + updateFileStreamClient grpc2.ClientStreamingClient[mpi.FileDataChunk, mpi.UpdateFileResponse], +) error { + messageMeta := &mpi.MessageMeta{ + MessageId: id.GenerateMessageID(), + CorrelationId: logger.GetCorrelationID(ctx), + Timestamp: timestamppb.Now(), + } + + backOffCtx, backoffCancel := context.WithTimeout(ctx, fms.agentConfig.Client.Backoff.MaxElapsedTime) + defer backoffCancel() + + sendUpdateFileChunk := func() error { + slog.DebugContext(ctx, "Sending update file stream chunk", "chunk_id", chunk.Content.GetChunkId()) + if fms.fileServiceClient == nil { + return errors.New("file service client is not initialized") + } + + if !fms.isConnected.Load() { + return errors.New("CreateConnection rpc has not being called yet") + } + + err := updateFileStreamClient.Send( + &mpi.FileDataChunk{ + Meta: messageMeta, + Chunk: &chunk, + }, + ) + + validatedError := grpc.ValidateGrpcError(err) + + if validatedError != nil { + slog.ErrorContext(ctx, "Failed to send update file stream chunk", "error", validatedError) + + return validatedError + } + + return nil + } + + return backoff.Retry(sendUpdateFileChunk, backoffHelpers.Context(backOffCtx, fms.agentConfig.Client.Backoff)) } func (fms *FileManagerService) IsConnected() bool { @@ -410,6 +632,17 @@ func (fms *FileManagerService) executeFileActions(ctx context.Context) error { } func (fms *FileManagerService) fileUpdate(ctx context.Context, file *mpi.File) error { + slog.DebugContext(ctx, "Updating file", "file", file.GetFileMeta().GetName()) + if file.GetFileMeta().GetSize() <= int64(fms.agentConfig.Client.Grpc.MaxMessageReceiveSize) { + return fms.getFile(ctx, file) + } + + return fms.getChunkedFile(ctx, file) +} + +func (fms *FileManagerService) getFile(ctx context.Context, file *mpi.File) error { + slog.DebugContext(ctx, "Getting file", "file", file.GetFileMeta().GetName()) + backOffCtx, backoffCancel := context.WithTimeout(ctx, fms.agentConfig.Client.Backoff.MaxElapsedTime) defer backoffCancel() @@ -438,9 +671,82 @@ func (fms *FileManagerService) fileUpdate(ctx context.Context, file *mpi.File) e return writeErr } - validateErr := fms.validateFileHash(file.GetFileMeta().GetName()) + return fms.validateFileHash(file.GetFileMeta().GetName()) +} + +func (fms *FileManagerService) getChunkedFile(ctx context.Context, file *mpi.File) error { + slog.DebugContext(ctx, "Getting chunked file", "file", 
file.GetFileMeta().GetName()) + + stream, err := fms.fileServiceClient.GetFileStream(ctx, &mpi.GetFileRequest{ + MessageMeta: &mpi.MessageMeta{ + MessageId: id.GenerateMessageID(), + CorrelationId: logger.GetCorrelationID(ctx), + Timestamp: timestamppb.Now(), + }, + FileMeta: file.GetFileMeta(), + }) + if err != nil { + return fmt.Errorf("error getting file stream for %s: %w", file.GetFileMeta().GetName(), err) + } + + // Get header chunk first + headerChunk, recvHeaderChunkError := stream.Recv() + if recvHeaderChunkError != nil { + return recvHeaderChunkError + } + + slog.DebugContext(ctx, "File header chunk received", "header_chunk", headerChunk) + + header := headerChunk.GetHeader() + + writeChunkedFileError := fms.writeChunkedFile(ctx, file, header, stream) + if writeChunkedFileError != nil { + return writeChunkedFileError + } + + return nil +} + +func (fms *FileManagerService) writeChunkedFile( + ctx context.Context, + file *mpi.File, + header *mpi.FileDataChunkHeader, + stream grpc2.ServerStreamingClient[mpi.FileDataChunk], +) error { + filePermissions := files.FileMode(file.GetFileMeta().GetPermissions()) + createFileDirectoriesError := fms.fileOperator.CreateFileDirectories(ctx, file.GetFileMeta(), filePermissions) + if createFileDirectoriesError != nil { + return createFileDirectoriesError + } + + fileToWrite, createError := os.Create(file.GetFileMeta().GetName()) + defer func() { + closeError := fileToWrite.Close() + if closeError != nil { + slog.WarnContext( + ctx, "Failed to close file", + "file", file.GetFileMeta().GetName(), + "error", closeError, + ) + } + }() + if createError != nil { + return createError + } + + for i := uint32(0); i < header.GetChunks(); i++ { + chunk, recvError := stream.Recv() + if recvError != nil { + return recvError + } - return validateErr + _, chunkWriteError := fileToWrite.Write(chunk.GetContent().GetData()) + if chunkWriteError != nil { + return fmt.Errorf("error writing chunk to file %s: %w", file.GetFileMeta().GetName(), chunkWriteError) + } + } + + return nil } func (fms *FileManagerService) validateFileHash(filePath string) error { diff --git a/internal/file/file_manager_service_test.go b/internal/file/file_manager_service_test.go index 95e138c2a..96d63786f 100644 --- a/internal/file/file_manager_service_test.go +++ b/internal/file/file_manager_service_test.go @@ -10,8 +10,11 @@ import ( "fmt" "os" "path/filepath" + "sync/atomic" "testing" + "google.golang.org/grpc/metadata" + "github.com/nginx/agent/v3/internal/model" "github.com/nginx/agent/v3/pkg/files" @@ -26,6 +29,107 @@ import ( "github.com/stretchr/testify/require" ) +type FakeClientStreamingClient struct { + sendCount atomic.Int32 +} + +func (f *FakeClientStreamingClient) Send(req *mpi.FileDataChunk) error { + f.sendCount.Add(1) + return nil +} + +func (f *FakeClientStreamingClient) CloseAndRecv() (*mpi.UpdateFileResponse, error) { + return &mpi.UpdateFileResponse{}, nil +} + +func (f *FakeClientStreamingClient) Header() (metadata.MD, error) { + return metadata.MD{}, nil +} + +func (f *FakeClientStreamingClient) Trailer() metadata.MD { + return nil +} + +func (f *FakeClientStreamingClient) CloseSend() error { + return nil +} + +func (f *FakeClientStreamingClient) Context() context.Context { + return context.Background() +} + +func (f *FakeClientStreamingClient) SendMsg(m any) error { + return nil +} + +func (f *FakeClientStreamingClient) RecvMsg(m any) error { + return nil +} + +type FakeServerStreamingClient struct { + chunks map[uint32][]byte + fileName string + currentChunkID 
uint32 +} + +func (f *FakeServerStreamingClient) Recv() (*mpi.FileDataChunk, error) { + fileDataChunk := &mpi.FileDataChunk{ + Meta: &mpi.MessageMeta{ + MessageId: "123", + CorrelationId: "1234", + Timestamp: timestamppb.Now(), + }, + } + + if f.currentChunkID == 0 { + fileDataChunk.Chunk = &mpi.FileDataChunk_Header{ + Header: &mpi.FileDataChunkHeader{ + FileMeta: &mpi.FileMeta{ + Name: f.fileName, + Permissions: "666", + }, + Chunks: 52, + ChunkSize: 1, + }, + } + } else { + fileDataChunk.Chunk = &mpi.FileDataChunk_Content{ + Content: &mpi.FileDataChunkContent{ + ChunkId: f.currentChunkID, + Data: f.chunks[f.currentChunkID-1], + }, + } + } + + f.currentChunkID++ + + return fileDataChunk, nil +} + +func (f *FakeServerStreamingClient) Header() (metadata.MD, error) { + return metadata.MD{}, nil +} + +func (f *FakeServerStreamingClient) Trailer() metadata.MD { + return metadata.MD{} +} + +func (f *FakeServerStreamingClient) CloseSend() error { + return nil +} + +func (f *FakeServerStreamingClient) Context() context.Context { + return context.Background() +} + +func (f *FakeServerStreamingClient) SendMsg(m any) error { + return nil +} + +func (f *FakeServerStreamingClient) RecvMsg(m any) error { + return nil +} + func TestFileManagerService_UpdateOverview(t *testing.T) { ctx := context.Background() @@ -142,6 +246,30 @@ func TestFileManagerService_UpdateFile(t *testing.T) { } } +func TestFileManagerService_UpdateFile_LargeFile(t *testing.T) { + ctx := context.Background() + tempDir := os.TempDir() + + testFile := helpers.CreateFileWithErrorCheck(t, tempDir, "nginx.conf") + writeFileError := os.WriteFile(testFile.Name(), []byte("#test content"), 0o600) + require.NoError(t, writeFileError) + fileMeta := protos.FileMetaLargeFile(testFile.Name(), "") + + fakeFileServiceClient := &v1fakes.FakeFileServiceClient{} + fakeClientStreamingClient := &FakeClientStreamingClient{sendCount: atomic.Int32{}} + fakeFileServiceClient.UpdateFileStreamReturns(fakeClientStreamingClient, nil) + fileManagerService := NewFileManagerService(fakeFileServiceClient, types.AgentConfig()) + fileManagerService.SetIsConnected(true) + + err := fileManagerService.UpdateFile(ctx, "123", &mpi.File{FileMeta: fileMeta}) + + require.NoError(t, err) + assert.Equal(t, 0, fakeFileServiceClient.UpdateFileCallCount()) + assert.Equal(t, 14, int(fakeClientStreamingClient.sendCount.Load())) + + helpers.RemoveFileWithErrorCheck(t, testFile.Name()) +} + func TestFileManagerService_ConfigApply_Add(t *testing.T) { ctx := context.Background() tempDir := t.TempDir() @@ -178,6 +306,59 @@ func TestFileManagerService_ConfigApply_Add(t *testing.T) { require.NoError(t, readErr) assert.Equal(t, fileContent, data) assert.Equal(t, fileManagerService.fileActions[filePath], overview.GetFiles()[0]) + assert.Equal(t, 1, fakeFileServiceClient.GetFileCallCount()) +} + +func TestFileManagerService_ConfigApply_Add_LargeFile(t *testing.T) { + ctx := context.Background() + tempDir := t.TempDir() + + filePath := filepath.Join(tempDir, "nginx.conf") + fileContent := []byte("location /test {\n return 200 \"Test location\\n\";\n}") + fileHash := files.GenerateHash(fileContent) + defer helpers.RemoveFileWithErrorCheck(t, filePath) + + overview := protos.FileOverviewLargeFile(filePath, fileHash) + + manifestDirPath = tempDir + manifestFilePath = manifestDirPath + "/manifest.json" + helpers.CreateFileWithErrorCheck(t, manifestDirPath, "manifest.json") + + fakeFileServiceClient := &v1fakes.FakeFileServiceClient{} + 
fakeFileServiceClient.GetOverviewReturns(&mpi.GetOverviewResponse{ + Overview: overview, + }, nil) + + fakeServerStreamingClient := &FakeServerStreamingClient{ + chunks: make(map[uint32][]byte), + currentChunkID: 0, + fileName: filePath, + } + + t.Logf("fakeServerStreamingClient: %v", fakeServerStreamingClient) + + for i := 0; i < len(fileContent); i++ { + fakeServerStreamingClient.chunks[uint32(i)] = []byte{fileContent[i]} + } + + t.Logf("fakeServerStreamingClient: %v", fakeServerStreamingClient) + + fakeFileServiceClient.GetFileStreamReturns(fakeServerStreamingClient, nil) + agentConfig := types.AgentConfig() + agentConfig.AllowedDirectories = []string{tempDir} + fileManagerService := NewFileManagerService(fakeFileServiceClient, agentConfig) + + request := protos.CreateConfigApplyRequest(overview) + t.Logf("request: %v", request) + writeStatus, err := fileManagerService.ConfigApply(ctx, request) + require.NoError(t, err) + assert.Equal(t, model.OK, writeStatus) + data, readErr := os.ReadFile(filePath) + require.NoError(t, readErr) + assert.Equal(t, fileContent, data) + assert.Equal(t, fileManagerService.fileActions[filePath], overview.GetFiles()[0]) + assert.Equal(t, 0, fakeFileServiceClient.GetFileCallCount()) + assert.Equal(t, 53, int(fakeServerStreamingClient.currentChunkID)) } func TestFileManagerService_ConfigApply_Update(t *testing.T) { diff --git a/internal/file/file_operator.go b/internal/file/file_operator.go index 06c174b90..171d40947 100644 --- a/internal/file/file_operator.go +++ b/internal/file/file_operator.go @@ -29,19 +29,36 @@ func NewFileOperator() *FileOperator { func (fo *FileOperator) Write(ctx context.Context, fileContent []byte, file *mpi.FileMeta) error { filePermission := files.FileMode(file.GetPermissions()) - if _, err := os.Stat(file.GetName()); os.IsNotExist(err) { - slog.DebugContext(ctx, "File does not exist, creating new file", "file_path", file.GetName()) - err = os.MkdirAll(path.Dir(file.GetName()), filePermission) - if err != nil { - return fmt.Errorf("error creating directory %s: %w", path.Dir(file.GetName()), err) - } + err := fo.CreateFileDirectories(ctx, file, filePermission) + if err != nil { + return err } - err := os.WriteFile(file.GetName(), fileContent, filePermission) - if err != nil { - return fmt.Errorf("error writing to file %s: %w", file.GetName(), err) + writeErr := os.WriteFile(file.GetName(), fileContent, filePermission) + if writeErr != nil { + return fmt.Errorf("error writing to file %s: %w", file.GetName(), writeErr) } slog.DebugContext(ctx, "Content written to file", "file_path", file.GetName()) return nil } + +func (fo *FileOperator) CreateFileDirectories( + ctx context.Context, + fileMeta *mpi.FileMeta, + filePermission os.FileMode, +) error { + if _, err := os.Stat(fileMeta.GetName()); os.IsNotExist(err) { + parentDirectory := path.Dir(fileMeta.GetName()) + slog.DebugContext( + ctx, "File does not exist, creating parent directory", + "directory_path", parentDirectory, + ) + err = os.MkdirAll(parentDirectory, filePermission) + if err != nil { + return fmt.Errorf("error creating directory %s: %w", parentDirectory, err) + } + } + + return nil +} diff --git a/internal/file/filefakes/fake_file_operator.go b/internal/file/filefakes/fake_file_operator.go index 6373caf06..946638b6a 100644 --- a/internal/file/filefakes/fake_file_operator.go +++ b/internal/file/filefakes/fake_file_operator.go @@ -3,12 +3,26 @@ package filefakes import ( "context" + "io/fs" "sync" v1 "github.com/nginx/agent/v3/api/grpc/mpi/v1" ) type FakeFileOperator struct 
{ + CreateFileDirectoriesStub func(context.Context, *v1.FileMeta, fs.FileMode) error + createFileDirectoriesMutex sync.RWMutex + createFileDirectoriesArgsForCall []struct { + arg1 context.Context + arg2 *v1.FileMeta + arg3 fs.FileMode + } + createFileDirectoriesReturns struct { + result1 error + } + createFileDirectoriesReturnsOnCall map[int]struct { + result1 error + } WriteStub func(context.Context, []byte, *v1.FileMeta) error writeMutex sync.RWMutex writeArgsForCall []struct { @@ -26,6 +40,69 @@ type FakeFileOperator struct { invocationsMutex sync.RWMutex } +func (fake *FakeFileOperator) CreateFileDirectories(arg1 context.Context, arg2 *v1.FileMeta, arg3 fs.FileMode) error { + fake.createFileDirectoriesMutex.Lock() + ret, specificReturn := fake.createFileDirectoriesReturnsOnCall[len(fake.createFileDirectoriesArgsForCall)] + fake.createFileDirectoriesArgsForCall = append(fake.createFileDirectoriesArgsForCall, struct { + arg1 context.Context + arg2 *v1.FileMeta + arg3 fs.FileMode + }{arg1, arg2, arg3}) + stub := fake.CreateFileDirectoriesStub + fakeReturns := fake.createFileDirectoriesReturns + fake.recordInvocation("CreateFileDirectories", []interface{}{arg1, arg2, arg3}) + fake.createFileDirectoriesMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeFileOperator) CreateFileDirectoriesCallCount() int { + fake.createFileDirectoriesMutex.RLock() + defer fake.createFileDirectoriesMutex.RUnlock() + return len(fake.createFileDirectoriesArgsForCall) +} + +func (fake *FakeFileOperator) CreateFileDirectoriesCalls(stub func(context.Context, *v1.FileMeta, fs.FileMode) error) { + fake.createFileDirectoriesMutex.Lock() + defer fake.createFileDirectoriesMutex.Unlock() + fake.CreateFileDirectoriesStub = stub +} + +func (fake *FakeFileOperator) CreateFileDirectoriesArgsForCall(i int) (context.Context, *v1.FileMeta, fs.FileMode) { + fake.createFileDirectoriesMutex.RLock() + defer fake.createFileDirectoriesMutex.RUnlock() + argsForCall := fake.createFileDirectoriesArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeFileOperator) CreateFileDirectoriesReturns(result1 error) { + fake.createFileDirectoriesMutex.Lock() + defer fake.createFileDirectoriesMutex.Unlock() + fake.CreateFileDirectoriesStub = nil + fake.createFileDirectoriesReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeFileOperator) CreateFileDirectoriesReturnsOnCall(i int, result1 error) { + fake.createFileDirectoriesMutex.Lock() + defer fake.createFileDirectoriesMutex.Unlock() + fake.CreateFileDirectoriesStub = nil + if fake.createFileDirectoriesReturnsOnCall == nil { + fake.createFileDirectoriesReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.createFileDirectoriesReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeFileOperator) Write(arg1 context.Context, arg2 []byte, arg3 *v1.FileMeta) error { var arg2Copy []byte if arg2 != nil { @@ -97,6 +174,8 @@ func (fake *FakeFileOperator) WriteReturnsOnCall(i int, result1 error) { func (fake *FakeFileOperator) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() + fake.createFileDirectoriesMutex.RLock() + defer fake.createFileDirectoriesMutex.RUnlock() fake.writeMutex.RLock() defer fake.writeMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} diff --git a/test/protos/files.go b/test/protos/files.go index 
36b0a50e5..3dbee464b 100644 --- a/test/protos/files.go +++ b/test/protos/files.go @@ -7,9 +7,12 @@ package protos import ( mpi "github.com/nginx/agent/v3/api/grpc/mpi/v1" + "github.com/nginx/agent/v3/internal/config" "google.golang.org/protobuf/types/known/timestamppb" ) +const largeFileSize = int64(10 * config.DefFileChunkSize) + func FileMeta(fileName, fileHash string) *mpi.FileMeta { lastModified, _ := CreateProtoTime("2024-01-09T13:22:21Z") @@ -21,6 +24,18 @@ func FileMeta(fileName, fileHash string) *mpi.FileMeta { } } +func FileMetaLargeFile(fileName, fileHash string) *mpi.FileMeta { + lastModified, _ := CreateProtoTime("2024-01-09T13:22:21Z") + + return &mpi.FileMeta{ + ModifiedTime: lastModified, + Name: fileName, + Hash: fileHash, + Permissions: "0600", + Size: largeFileSize, + } +} + func ManifestFileMeta(fileName, fileHash string) *mpi.FileMeta { return &mpi.FileMeta{ ModifiedTime: nil, @@ -83,6 +98,23 @@ func FileOverview(filePath, fileHash string) *mpi.FileOverview { } } +func FileOverviewLargeFile(filePath, fileHash string) *mpi.FileOverview { + return &mpi.FileOverview{ + Files: []*mpi.File{ + { + FileMeta: &mpi.FileMeta{ + Name: filePath, + Hash: fileHash, + ModifiedTime: timestamppb.Now(), + Permissions: "0640", + Size: largeFileSize, + }, + }, + }, + ConfigVersion: CreateConfigVersion(), + } +} + func FileContents(content []byte) *mpi.FileContents { return &mpi.FileContents{ Contents: content, diff --git a/test/types/config.go b/test/types/config.go index 74d71f2de..34730dcb0 100644 --- a/test/types/config.go +++ b/test/types/config.go @@ -48,6 +48,9 @@ func AgentConfig() *config.Config { Time: clientTime, PermitWithoutStream: clientPermitWithoutStream, }, + MaxMessageReceiveSize: 1, + MaxMessageSendSize: 1, + FileChunkSize: 1, }, Backoff: &config.BackOff{ InitialInterval: commonInitialInterval,
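
Note for reviewers (not part of the change): a standalone sketch of the chunk arithmetic behind the new test expectations. The stream header in sendUpdateFileStreamHeader computes ceil(FileMeta.Size / FileChunkSize) chunks, while sendFileUpdateStreamChunks reads the on-disk file until EOF, so the total number of Send calls is one FileDataChunk_Header plus one per content chunk read. With the 1-byte FileChunkSize in test/types/config.go, the 13-byte "#test content" file written by TestFileManagerService_UpdateFile_LargeFile produces 13 content chunks, matching the asserted send count of 14. The file name and values below are illustrative only.

package main

import (
	"fmt"
	"math"
)

// Illustration of the chunk math used by the streaming upload path above;
// values mirror the test AgentConfig, not the production defaults.
func main() {
	chunkSize := uint32(1)                   // FileChunkSize in test/types/config.go
	fileBytes := int64(len("#test content")) // 13 bytes written by the test

	contentChunks := uint32(math.Ceil(float64(fileBytes) / float64(chunkSize)))
	totalSends := contentChunks + 1 // one header chunk plus the content chunks

	fmt.Println(contentChunks, totalSends) // prints: 13 14
}

The same arithmetic explains the download-side assertion in TestFileManagerService_ConfigApply_Add_LargeFile: 52 one-byte content chunks plus the header chunk leave the fake stream's currentChunkID at 53.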