diff --git a/pkg/cmd/server/options.go b/pkg/cmd/server/options.go index 0945c81c..06788c19 100644 --- a/pkg/cmd/server/options.go +++ b/pkg/cmd/server/options.go @@ -8,16 +8,18 @@ import ( func NewServerOptions() *ServerOptions { return &ServerOptions{ - Mode: DefaultMode, - Port: DefaultPort, - AuthEnabled: false, - AuthWhitelist: []string{}, - AuthKeyType: DefaultAuthKeyType, - Database: DatabaseOptions{}, - DefaultBackend: DefaultBackendOptions{}, - DefaultSource: DefaultSourceOptions{}, - MaxConcurrent: constant.MaxConcurrent, - LogFilePath: constant.DefaultLogFilePath, + Mode: DefaultMode, + Port: DefaultPort, + AuthEnabled: false, + AuthWhitelist: []string{}, + AuthKeyType: DefaultAuthKeyType, + Database: DatabaseOptions{}, + DefaultBackend: DefaultBackendOptions{}, + DefaultSource: DefaultSourceOptions{}, + MaxConcurrent: constant.MaxConcurrent, + MaxAsyncConcurrent: constant.MaxAsyncConcurrent, + MaxAsyncBuffer: constant.MaxAsyncBuffer, + LogFilePath: constant.DefaultLogFilePath, } } @@ -37,6 +39,8 @@ func (o *ServerOptions) Config() (*server.Config, error) { cfg.AuthWhitelist = o.AuthWhitelist cfg.AuthKeyType = o.AuthKeyType cfg.MaxConcurrent = o.MaxConcurrent + cfg.MaxAsyncConcurrent = o.MaxAsyncConcurrent + cfg.MaxAsyncBuffer = o.MaxAsyncBuffer cfg.LogFilePath = o.LogFilePath return cfg, nil } diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index ed2d46b8..20873e6d 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -51,6 +51,10 @@ func (o *ServerOptions) AddServerFlags(cmd *cobra.Command) { i18n.T("Specify the auth key type. Default to RSA")) cmd.Flags().IntVarP(&o.MaxConcurrent, "max-concurrent", "", 10, i18n.T("Maximum number of concurrent executions including preview, apply and destroy. Default to 10.")) + cmd.Flags().IntVarP(&o.MaxAsyncBuffer, "max-async-buffer", "", 100, + i18n.T("Maximum number of buffer zones during concurrent async executions including generate, preview, apply and destroy. Default to 100.")) + cmd.Flags().IntVarP(&o.MaxAsyncConcurrent, "max-async-concurrent", "", 10, + i18n.T("Maximum number of concurrent async executions including generate, preview, apply and destroy. Default to 10.")) cmd.Flags().StringVarP(&o.LogFilePath, "log-file-path", "", constant.DefaultLogFilePath, i18n.T("File path to write logs to. Default to /home/admin/logs/kusion.log")) o.Database.AddFlags(cmd.Flags()) diff --git a/pkg/cmd/server/types.go b/pkg/cmd/server/types.go index 0eed2af6..677a7cdb 100644 --- a/pkg/cmd/server/types.go +++ b/pkg/cmd/server/types.go @@ -5,16 +5,18 @@ import ( ) type ServerOptions struct { - Mode string - Port int - AuthEnabled bool - AuthWhitelist []string - AuthKeyType string - Database DatabaseOptions - DefaultBackend DefaultBackendOptions - DefaultSource DefaultSourceOptions - MaxConcurrent int - LogFilePath string + Mode string + Port int + AuthEnabled bool + AuthWhitelist []string + AuthKeyType string + Database DatabaseOptions + DefaultBackend DefaultBackendOptions + DefaultSource DefaultSourceOptions + MaxConcurrent int + MaxAsyncConcurrent int + MaxAsyncBuffer int + LogFilePath string } type Options interface { diff --git a/pkg/domain/constant/global.go b/pkg/domain/constant/global.go index 2f0d8961..54b69e44 100644 --- a/pkg/domain/constant/global.go +++ b/pkg/domain/constant/global.go @@ -1,14 +1,21 @@ package constant +import "time" + // These constants represent the possible states of a stack. const ( - DefaultUser = "test.user" - DefaultWorkspace = "default" - DefaultBackend = "default" - DefaultOrgOwner = "kusion" - DefaultSourceType = SourceProviderTypeGit - DefaultSourceDesc = "Default source" - DefaultSystemName = "kusion" - MaxConcurrent = 10 - DefaultLogFilePath = "/home/admin/logs/kusion.log" + DefaultUser = "test.user" + DefaultWorkspace = "default" + DefaultBackend = "default" + DefaultOrgOwner = "kusion" + DefaultSourceType = SourceProviderTypeGit + DefaultSourceDesc = "Default source" + DefaultSystemName = "kusion" + DefaultReleaseNamespace = "server" + MaxConcurrent = 10 + MaxAsyncConcurrent = 1 + MaxAsyncBuffer = 100 + DefaultLogFilePath = "/home/admin/logs/kusion.log" + RepoCacheTTL = 60 * time.Minute + RunTimeOut = 60 * time.Minute ) diff --git a/pkg/domain/constant/run.go b/pkg/domain/constant/run.go new file mode 100644 index 00000000..ec3e5a7b --- /dev/null +++ b/pkg/domain/constant/run.go @@ -0,0 +1,61 @@ +package constant + +import ( + "fmt" +) + +type ( + RunType string + RunStatus string +) + +const ( + RunTypeGenerate RunType = "Generate" + RunTypePreview RunType = "Preview" + RunTypeApply RunType = "Apply" + RunTypeDestroy RunType = "Destroy" + RunStatusScheduling RunStatus = "Scheduling" + RunStatusInProgress RunStatus = "InProgress" + RunStatusFailed RunStatus = "Failed" + RunStatusSucceeded RunStatus = "Succeeded" + RunStatusCancelled RunStatus = "Cancelled" + RunStatusQueued RunStatus = "Queued" +) + +// ParseRunType parses a string into a RunType. +// If the string is not a valid RunType, it returns an error. +func ParseRunType(s string) (RunType, error) { + switch s { + case string(RunTypeGenerate): + return RunTypeGenerate, nil + case string(RunTypePreview): + return RunTypePreview, nil + case string(RunTypeApply): + return RunTypeApply, nil + case string(RunTypeDestroy): + return RunTypeDestroy, nil + default: + return RunType(""), fmt.Errorf("invalid RunType: %q", s) + } +} + +// ParseRunStatus parses a string into a RunStatus. +// If the string is not a valid RunStatus, it returns an error. +func ParseRunStatus(s string) (RunStatus, error) { + switch s { + case string(RunStatusScheduling): + return RunStatusScheduling, nil + case string(RunStatusInProgress): + return RunStatusInProgress, nil + case string(RunStatusFailed): + return RunStatusFailed, nil + case string(RunStatusSucceeded): + return RunStatusSucceeded, nil + case string(RunStatusCancelled): + return RunStatusCancelled, nil + case string(RunStatusQueued): + return RunStatusQueued, nil + default: + return RunStatus(""), fmt.Errorf("invalid RunType: %q", s) + } +} diff --git a/pkg/domain/entity/run.go b/pkg/domain/entity/run.go new file mode 100644 index 00000000..5728534a --- /dev/null +++ b/pkg/domain/entity/run.go @@ -0,0 +1,72 @@ +package entity + +import ( + "fmt" + "time" + + "kusionstack.io/kusion/pkg/domain/constant" +) + +// Run represents the specific run, including type +// which should be a specific instance of the run provider. +type Run struct { + // ID is the id of the run. + ID uint `yaml:"id" json:"id"` + // RunType is the type of the run provider. + Type constant.RunType `yaml:"type" json:"type"` + // Stack is the stack of the run. + Stack *Stack `yaml:"stack" json:"stack"` + // Workspace is the target workspace of the run. + Workspace string `yaml:"workspace" json:"workspace"` + // Status is the status of the run. + Status constant.RunStatus `yaml:"status" json:"status"` + // Result is the result of the run. + Result string `yaml:"result" json:"result"` + // Result RunResult `yaml:"result" json:"result"` + // Logs is the logs of the run. + Logs string `yaml:"logs" json:"logs"` + // CreationTimestamp is the timestamp of the created for the run. + CreationTimestamp time.Time `yaml:"creationTimestamp,omitempty" json:"creationTimestamp,omitempty"` + // UpdateTimestamp is the timestamp of the updated for the run. + UpdateTimestamp time.Time `yaml:"updateTimestamp,omitempty" json:"updateTimestamp,omitempty"` +} + +// RunResult represents the result of the run. +type RunResult struct { + // ExitCode is the exit code of the run. + ExitCode int `yaml:"exitCode" json:"exitCode"` + // Message is the message of the run. + Message string `yaml:"message" json:"message"` + // Old is the old state of the run. + Old string `yaml:"old" json:"old"` + // New is the new state of the run. + New string `yaml:"new" json:"new"` +} + +type RunFilter struct { + ProjectID uint + StackID uint + Workspace string +} + +// Validate checks if the run is valid. +// It returns an error if the run is not valid. +func (r *Run) Validate() error { + if r == nil { + return fmt.Errorf("run is nil") + } + + if r.Type == "" { + return fmt.Errorf("run must have a run type") + } + + if r.Workspace == "" { + return fmt.Errorf("run must have a target workspace") + } + + return nil +} + +func (r *Run) Summary() string { + return fmt.Sprintf("[%s][%s]", string(r.Type), string(r.Status)) +} diff --git a/pkg/domain/entity/source.go b/pkg/domain/entity/source.go index 9d591e04..335701ee 100644 --- a/pkg/domain/entity/source.go +++ b/pkg/domain/entity/source.go @@ -13,6 +13,8 @@ import ( type Source struct { // ID is the id of the source. ID uint `yaml:"id" json:"id"` + // Name is the name of the source. + Name string `yaml:"name" json:"name"` // SourceProvider is the type of the source provider. SourceProvider constant.SourceProviderType `yaml:"sourceProvider" json:"sourceProvider"` // Remote is the source URL, including scheme. @@ -36,6 +38,10 @@ func (s *Source) Validate() error { return fmt.Errorf("source is nil") } + if s.Name == "" { + return fmt.Errorf("source must have a name") + } + if s.SourceProvider == "" { return fmt.Errorf("source must have a source provider") } diff --git a/pkg/domain/repository/repository.go b/pkg/domain/repository/repository.go index 0b469f18..e0088d8a 100644 --- a/pkg/domain/repository/repository.go +++ b/pkg/domain/repository/repository.go @@ -137,3 +137,18 @@ type ModuleRepository interface { // List retrives all the existing modules. List(ctx context.Context) ([]*entity.Module, error) } + +// RunRepository is an interface that defines the repository operations +// for runs. It follows the principles of domain-driven design (DDD). +type RunRepository interface { + // Create creates a new run. + Create(ctx context.Context, run *entity.Run) error + // Delete deletes a run by its ID. + Delete(ctx context.Context, id uint) error + // Update updates an existing run. + Update(ctx context.Context, run *entity.Run) error + // Get retrieves a run by its ID. + Get(ctx context.Context, id uint) (*entity.Run, error) + // List retrieves all existing run. + List(ctx context.Context, filter *entity.RunFilter) ([]*entity.Run, error) +} diff --git a/pkg/domain/request/execute_request.go b/pkg/domain/request/execute_request.go index 42bfc999..f0129a97 100644 --- a/pkg/domain/request/execute_request.go +++ b/pkg/domain/request/execute_request.go @@ -11,3 +11,24 @@ type StackImportRequest struct { func (payload *StackImportRequest) Decode(r *http.Request) error { return decode(r, payload) } + +type CreateRunRequest struct { + Type string `json:"type"` + StackID uint `json:"stackID"` + Workspace string `json:"workspace"` + ImportedResources StackImportRequest `json:"importedResources"` +} + +type UpdateRunRequest struct { + CreateRunRequest `json:",inline" yaml:",inline"` +} + +type UpdateRunResultRequest struct { + Result string `json:"result"` + Status string `json:"status"` + Logs string `json:"logs"` +} + +func (payload *CreateRunRequest) Decode(r *http.Request) error { + return decode(r, payload) +} diff --git a/pkg/domain/request/source_request.go b/pkg/domain/request/source_request.go index 7e48abff..baa1b14d 100644 --- a/pkg/domain/request/source_request.go +++ b/pkg/domain/request/source_request.go @@ -5,6 +5,8 @@ import "net/http" // CreateSourceRequest represents the create request structure for // source. type CreateSourceRequest struct { + // Name is the name of the source. + Name string `json:"name" binding:"required"` // SourceProvider is the type of the source provider. SourceProvider string `json:"sourceProvider" binding:"required"` // Remote is the source URL, including scheme. @@ -21,17 +23,8 @@ type CreateSourceRequest struct { // source. type UpdateSourceRequest struct { // ID is the id of the source. - ID uint `json:"id" binding:"required"` - // SourceProvider is the type of the source provider. - SourceProvider string `json:"sourceProvider"` - // Remote is the source URL, including scheme. - Remote string `json:"remote"` - // Description is a human-readable description of the source. - Description string `json:"description"` - // Labels are custom labels associated with the source. - Labels []string `json:"labels"` - // Owners is a list of owners for the source. - Owners []string `json:"owners"` + ID uint `json:"id" binding:"required"` + CreateSourceRequest `json:",inline" yaml:",inline"` } func (payload *CreateSourceRequest) Decode(r *http.Request) error { diff --git a/pkg/engine/api/apply.go b/pkg/engine/api/apply.go index f8695492..453981a2 100644 --- a/pkg/engine/api/apply.go +++ b/pkg/engine/api/apply.go @@ -30,6 +30,7 @@ func Apply( o *APIOptions, storage release.Storage, rel *apiv1.Release, + gph *apiv1.Graph, changes *models.Changes, out io.Writer, ) (*apiv1.Release, error) { @@ -135,6 +136,7 @@ func Apply( Stack: changes.Stack(), }, Release: rel, + Graph: gph, }) if v1.IsErr(st) { return nil, fmt.Errorf("apply failed, status:\n%v", st) diff --git a/pkg/engine/api/apply_test.go b/pkg/engine/api/apply_test.go index a17196bc..ed5ab125 100644 --- a/pkg/engine/api/apply_test.go +++ b/pkg/engine/api/apply_test.go @@ -62,7 +62,7 @@ func TestApply(t *testing.T) { changes := models.NewChanges(proj, stack, order) o := &APIOptions{} o.DryRun = true - _, err := Apply(context.TODO(), o, &releasestorages.LocalStorage{}, rel, changes, os.Stdout) + _, err := Apply(context.TODO(), o, &releasestorages.LocalStorage{}, rel, &apiv1.Graph{}, changes, os.Stdout) assert.Nil(t, err) }) mockey.PatchConvey("apply success", t, func() { @@ -86,7 +86,7 @@ func TestApply(t *testing.T) { } changes := models.NewChanges(proj, stack, order) - _, err := Apply(context.TODO(), o, &releasestorages.LocalStorage{}, rel, changes, os.Stdout) + _, err := Apply(context.TODO(), o, &releasestorages.LocalStorage{}, rel, &apiv1.Graph{}, changes, os.Stdout) assert.Nil(t, err) }) mockey.PatchConvey("apply failed", t, func() { @@ -105,8 +105,7 @@ func TestApply(t *testing.T) { }, } changes := models.NewChanges(proj, stack, order) - - _, err := Apply(context.TODO(), o, &releasestorages.LocalStorage{}, rel, changes, os.Stdout) + _, err := Apply(context.TODO(), o, &releasestorages.LocalStorage{}, rel, &apiv1.Graph{}, changes, os.Stdout) assert.NotNil(t, err) }) } diff --git a/pkg/infra/persistence/project_test.go b/pkg/infra/persistence/project_test.go index 6f77f957..3a634ba5 100644 --- a/pkg/infra/persistence/project_test.go +++ b/pkg/infra/persistence/project_test.go @@ -30,6 +30,7 @@ func TestProjectRepository(t *testing.T) { Name: "mockedProject", Source: &entity.Source{ ID: 1, + Name: "mockedSource", SourceProvider: constant.SourceProviderTypeGithub, Remote: mockRemoteURL, }, diff --git a/pkg/infra/persistence/run.go b/pkg/infra/persistence/run.go new file mode 100644 index 00000000..5f399dc8 --- /dev/null +++ b/pkg/infra/persistence/run.go @@ -0,0 +1,121 @@ +package persistence + +import ( + "context" + + "gorm.io/gorm" + "kusionstack.io/kusion/pkg/domain/entity" + "kusionstack.io/kusion/pkg/domain/repository" +) + +// The runRepository type implements the repository.RunRepository interface. +// If the runRepository type does not implement all the methods of the interface, +// the compiler will produce an error. +var _ repository.RunRepository = &runRepository{} + +// runRepository is a repository that stores runs in a gorm database. +type runRepository struct { + // db is the underlying gorm database where runs are stored. + db *gorm.DB +} + +// NewRunRepository creates a new run repository. +func NewRunRepository(db *gorm.DB) repository.RunRepository { + return &runRepository{db: db} +} + +// Create saves a run to the repository. +func (r *runRepository) Create(ctx context.Context, dataEntity *entity.Run) error { + // r.db.AutoMigrate(&RunModel{}) + err := dataEntity.Validate() + if err != nil { + return err + } + + // Map the data from Entity to DO + var dataModel RunModel + err = dataModel.FromEntity(dataEntity) + if err != nil { + return err + } + + return r.db.Transaction(func(tx *gorm.DB) error { + err = tx.WithContext(ctx).Create(&dataModel).Error + if err != nil { + return err + } + + dataEntity.ID = dataModel.ID + + return nil + }) +} + +// Delete removes a run from the repository. +func (r *runRepository) Delete(ctx context.Context, id uint) error { + return r.db.Transaction(func(tx *gorm.DB) error { + var dataModel RunModel + err := tx.WithContext(ctx).First(&dataModel, id).Error + if err != nil { + return err + } + + return tx.WithContext(ctx).Unscoped().Delete(&dataModel).Error + }) +} + +// Update updates an existing run in the repository. +func (r *runRepository) Update(ctx context.Context, dataEntity *entity.Run) error { + // Map the data from Entity to DO + var dataModel RunModel + err := dataModel.FromEntity(dataEntity) + if err != nil { + return err + } + + err = r.db.WithContext(ctx).Updates(&dataModel).Error + if err != nil { + return err + } + + return nil +} + +// Get retrieves a run by its ID. +func (r *runRepository) Get(ctx context.Context, id uint) (*entity.Run, error) { + var dataModel RunModel + err := r.db.WithContext(ctx). + Preload("Stack").Preload("Stack.Project"). + Joins("JOIN Stack ON Stack.id = Run.stack_id"). + Joins("JOIN Project ON Project.id = Stack.project_id"). + First(&dataModel, id).Error + if err != nil { + return nil, err + } + return dataModel.ToEntity() +} + +// List retrieves all runs. +func (r *runRepository) List(ctx context.Context, filter *entity.RunFilter) ([]*entity.Run, error) { + var dataModel []RunModel + runEntityList := make([]*entity.Run, 0) + pattern, args := GetRunQuery(filter) + result := r.db.WithContext(ctx). + Preload("Stack").Preload("Stack.Project"). + Joins("JOIN Stack ON Stack.id = Run.stack_id"). + Joins("JOIN Project ON Project.id = Stack.project_id"). + Joins("JOIN Workspace ON Workspace.name = Run.workspace"). + Where(pattern, args...). + Find(&dataModel) + if result.Error != nil { + return nil, result.Error + } + for _, run := range dataModel { + runEntity, err := run.ToEntity() + if err != nil { + return nil, err + } + runEntityList = append(runEntityList, runEntity) + } + return runEntityList, nil +} diff --git a/pkg/infra/persistence/run_model.go b/pkg/infra/persistence/run_model.go new file mode 100644 index 00000000..48f7f946 --- /dev/null +++ b/pkg/infra/persistence/run_model.go @@ -0,0 +1,90 @@ +package persistence + +import ( + "kusionstack.io/kusion/pkg/domain/constant" + "kusionstack.io/kusion/pkg/domain/entity" + + "gorm.io/gorm" +) + +// RunModel is a DO used to map the entity to the database. +type RunModel struct { + gorm.Model + // RunType is the type of the run. + Type string + // StackID is the stack ID of the run. + StackID uint + // Stack is the stack of the run. + Stack *StackModel `gorm:"foreignKey:ID;references:StackID"` + // Workspace is the target workspace of the run. + Workspace string + // Status is the status of the run. + Status string + // Result is the result of the run. + Result string + // Logs is the logs of the run. + Logs string +} + +// The TableName method returns the name of the database table that the struct is mapped to. +func (m *RunModel) TableName() string { + return "run" +} + +// ToEntity converts the DO to an entity. +func (m *RunModel) ToEntity() (*entity.Run, error) { + if m == nil { + return nil, ErrRunModelNil + } + + runType, err := constant.ParseRunType(m.Type) + if err != nil { + return nil, ErrFailedToGetRunType + } + + runStatus, err := constant.ParseRunStatus(m.Status) + if err != nil { + return nil, ErrFailedToGetRunStatus + } + + stackEntity, err := m.Stack.ToEntity() + if err != nil { + return nil, err + } + + return &entity.Run{ + ID: m.ID, + Type: runType, + Stack: stackEntity, + Workspace: m.Workspace, + Status: runStatus, + Result: m.Result, + // Result: entity.RunResult{}, + Logs: m.Logs, + CreationTimestamp: m.CreatedAt, + UpdateTimestamp: m.UpdatedAt, + }, nil +} + +// FromEntity converts an entity to a DO. +func (m *RunModel) FromEntity(e *entity.Run) error { + if m == nil { + return ErrRunModelNil + } + + if e.Stack != nil { + m.StackID = e.Stack.ID + m.Stack.FromEntity(e.Stack) + } + + m.ID = e.ID + m.Type = string(e.Type) + m.Workspace = e.Workspace + m.Status = string(e.Status) + m.Result = e.Result + m.Logs = e.Logs + m.CreatedAt = e.CreationTimestamp + m.UpdatedAt = e.UpdateTimestamp + + return nil +} diff --git a/pkg/infra/persistence/source_model.go b/pkg/infra/persistence/source_model.go index 7835a46f..825334df 100644 --- a/pkg/infra/persistence/source_model.go +++ b/pkg/infra/persistence/source_model.go @@ -12,7 +12,10 @@ import ( // SourceModel is a DO used to map the entity to the database. type SourceModel struct { gorm.Model + // Name is the name of the source. + Name string `gorm:"index:unique_source,unique"` // SourceProvider is the type of the source provider. + // TODO: remove uk here SourceProvider string `gorm:"index:unique_source,unique"` // Remote is the source URL, including scheme. Remote string `gorm:"index:unique_source,unique"` @@ -53,6 +56,7 @@ func (m *SourceModel) ToEntity() (*entity.Source, error) { return &entity.Source{ ID: m.ID, + Name: m.Name, SourceProvider: sourceProvider, Remote: remote, Description: m.Description, @@ -76,6 +80,7 @@ func (m *SourceModel) FromEntity(e *entity.Source) error { } m.ID = e.ID + m.Name = e.Name m.SourceProvider = string(e.SourceProvider) m.Description = e.Description m.Labels = MultiString(e.Labels) diff --git a/pkg/infra/persistence/source_test.go b/pkg/infra/persistence/source_test.go index 99a5f1d4..21853a11 100644 --- a/pkg/infra/persistence/source_test.go +++ b/pkg/infra/persistence/source_test.go @@ -27,6 +27,7 @@ func TestSourceRepository(t *testing.T) { var ( expectedID, expectedRows uint = 1, 1 actual = entity.Source{ + Name: "mockedSource", SourceProvider: constant.SourceProviderTypeOCI, Remote: mockRemoteURL, Description: "i am a description", diff --git a/pkg/infra/persistence/stack_test.go b/pkg/infra/persistence/stack_test.go index 63b3bce2..bc57711c 100644 --- a/pkg/infra/persistence/stack_test.go +++ b/pkg/infra/persistence/stack_test.go @@ -35,6 +35,7 @@ func TestStackRepository(t *testing.T) { Path: "/path/to/project", Source: &entity.Source{ ID: 1, + Name: "mockedSource", SourceProvider: constant.SourceProviderTypeGithub, Remote: mockRemoteURL, }, diff --git a/pkg/infra/persistence/types.go b/pkg/infra/persistence/types.go index dad9e62c..356f45b1 100644 --- a/pkg/infra/persistence/types.go +++ b/pkg/infra/persistence/types.go @@ -25,4 +25,7 @@ var ( ErrFailedToGetModuleRemote = errors.New("failed to parse module remote") ErrResourceModelNil = errors.New("resource model can't be nil") ErrFailedToGetModuleDocRemote = errors.New("failed to parse module doc remote") + ErrRunModelNil = errors.New("run model can't be nil") + ErrFailedToGetRunType = errors.New("failed to parse run type") + ErrFailedToGetRunStatus = errors.New("failed to parse run status") ) diff --git a/pkg/infra/persistence/util.go b/pkg/infra/persistence/util.go index bb1808f2..3cad9429 100644 --- a/pkg/infra/persistence/util.go +++ b/pkg/infra/persistence/util.go @@ -161,6 +161,24 @@ func GetResourceQuery(filter *entity.ResourceFilter) (string, []interface{}) { return CombineQueryParts(pattern), args } +func GetRunQuery(filter *entity.RunFilter) (string, []interface{}) { + pattern := make([]string, 0) + args := make([]interface{}, 0) + if filter.ProjectID != 0 { + pattern = append(pattern, "Project.ID = ?") + args = append(args, fmt.Sprint(filter.ProjectID)) + } + if filter.StackID != 0 { + pattern = append(pattern, "stack_id = ?") + args = append(args, filter.StackID) + } + if filter.Workspace != "" { + pattern = append(pattern, "Workspace.name = ?") + args = append(args, filter.Workspace) + } + return CombineQueryParts(pattern), args +} + func CombineQueryParts(queryParts []string) string { queryString := "" if len(queryParts) > 0 { @@ -197,5 +215,8 @@ func AutoMigrate(db *gorm.DB) error { if err := db.AutoMigrate(&ModuleModel{}); err != nil { return err } + if err := db.AutoMigrate(&RunModel{}); err != nil { + return err + } return nil } diff --git a/pkg/infra/util/worker/worker.go b/pkg/infra/util/worker/worker.go new file mode 100644 index 00000000..2b3ea466 --- /dev/null +++ b/pkg/infra/util/worker/worker.go @@ -0,0 +1,64 @@ +package worker + +import ( + "sync" +) + +type WorkerPool struct { + tasks chan func() // use channel to store tasks + wg sync.WaitGroup + numAvailableWorkers int // number of available workers + mu sync.Mutex // lock read/write of numAvailableWorkers +} + +func NewWorkerPool(maxConcurrentGoroutines, maxBufferGoroutines int) *WorkerPool { + pool := &WorkerPool{ + tasks: make(chan func(), maxBufferGoroutines), + numAvailableWorkers: maxConcurrentGoroutines, // initialize worker count + } + + for i := 0; i < maxConcurrentGoroutines; i++ { + go func() { + for task := range pool.tasks { + pool.mu.Lock() + pool.numAvailableWorkers-- // lower worker count + pool.mu.Unlock() + + task() // execute the task + + pool.mu.Lock() + pool.numAvailableWorkers++ // increase worker count + pool.mu.Unlock() + } + }() + } + + return pool +} + +// Do add the task to worker pool and return whether it is added to the execution zone or buffer zone +func (p *WorkerPool) Do(task func()) bool { + p.wg.Add(1) + inBufferZone := true + + // check available worker + p.mu.Lock() + if p.numAvailableWorkers > 0 { + inBufferZone = false + } + p.mu.Unlock() + + // place the task into the channel + p.tasks <- func() { + defer p.wg.Done() + task() + } + + return inBufferZone +} + +// Wait for all tasks before closing the channel +func (p *WorkerPool) Wait() { + p.wg.Wait() + close(p.tasks) // close the channel and stop the goroutines +} diff --git a/pkg/server/config.go b/pkg/server/config.go index 0c90ce14..0869bba7 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -6,16 +6,18 @@ import ( ) type Config struct { - DB *gorm.DB - DefaultBackend entity.Backend - DefaultSource entity.Source - Port int - AuthEnabled bool - AuthWhitelist []string - AuthKeyType string - MaxConcurrent int - LogFilePath string - AutoMigrate bool + DB *gorm.DB + DefaultBackend entity.Backend + DefaultSource entity.Source + Port int + AuthEnabled bool + AuthWhitelist []string + AuthKeyType string + MaxConcurrent int + MaxAsyncConcurrent int + MaxAsyncBuffer int + LogFilePath string + AutoMigrate bool } func NewConfig() *Config { diff --git a/pkg/server/handler/project/handler_test.go b/pkg/server/handler/project/handler_test.go index 870ceefe..c2177fe5 100644 --- a/pkg/server/handler/project/handler_test.go +++ b/pkg/server/handler/project/handler_test.go @@ -119,8 +119,8 @@ func TestProjectHandler(t *testing.T) { req.Header.Add("Content-Type", "application/json") sqlMock.ExpectQuery("SELECT"). - WillReturnRows(sqlmock.NewRows([]string{"id", "remote", "source_provider"}). - AddRow(1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "remote", "source_provider"}). + AddRow(1, "test-source", "https://github.com/test/repo", constant.SourceProviderTypeGithub)) sqlMock.ExpectQuery("SELECT"). WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owners"}). AddRow(1, "test-org", owners)) @@ -170,8 +170,8 @@ func TestProjectHandler(t *testing.T) { req.Header.Add("Content-Type", "application/json") sqlMock.ExpectQuery("SELECT"). - WillReturnRows(sqlmock.NewRows([]string{"id", "remote", "source_provider"}). - AddRow(1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "remote", "source_provider"}). + AddRow(1, "test-source", "https://github.com/test/repo", constant.SourceProviderTypeGithub)) sqlMock.ExpectQuery("SELECT"). WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owners"}). AddRow(1, "test-org", owners)) diff --git a/pkg/server/handler/source/handler_test.go b/pkg/server/handler/source/handler_test.go index 486dc506..2353ce8c 100644 --- a/pkg/server/handler/source/handler_test.go +++ b/pkg/server/handler/source/handler_test.go @@ -102,6 +102,7 @@ func TestSourceHandler(t *testing.T) { // Set request body requestPayload := request.CreateSourceRequest{ // Set your request payload fields here + Name: "test-source", SourceProvider: string(constant.SourceProviderTypeGithub), Remote: "https://github.com/test/remote", } @@ -147,9 +148,11 @@ func TestSourceHandler(t *testing.T) { // Set request body requestPayload := request.UpdateSourceRequest{ // Set your request payload fields here - ID: 1, - SourceProvider: string(constant.SourceProviderTypeGithub), - Remote: "https://github.com/test/updated-remote", + ID: 1, + CreateSourceRequest: request.CreateSourceRequest{ + SourceProvider: string(constant.SourceProviderTypeGithub), + Remote: "https://github.com/test/updated-remote", + }, } reqBody, err := json.Marshal(requestPayload) assert.NoError(t, err) @@ -261,9 +264,11 @@ func TestSourceHandler(t *testing.T) { // Set request body requestPayload := request.UpdateSourceRequest{ // Set your request payload fields here - ID: 1, - SourceProvider: string(constant.SourceProviderTypeGithub), - Remote: "https://github.com/test/updated-remote", + ID: 1, + CreateSourceRequest: request.CreateSourceRequest{ + SourceProvider: string(constant.SourceProviderTypeGithub), + Remote: "https://github.com/test/updated-remote", + }, } reqBody, err := json.Marshal(requestPayload) assert.NoError(t, err) diff --git a/pkg/server/handler/stack/execute.go b/pkg/server/handler/stack/execute.go index 0a34bdee..3af7fb02 100644 --- a/pkg/server/handler/stack/execute.go +++ b/pkg/server/handler/stack/execute.go @@ -1,25 +1,17 @@ package stack import ( - "context" "fmt" "io" "net/http" - "strconv" - "github.com/go-chi/chi/v5" - "github.com/go-chi/httplog/v2" "github.com/go-chi/render" yamlv2 "gopkg.in/yaml.v2" - "kusionstack.io/kusion/pkg/domain/constant" "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/server/handler" stackmanager "kusionstack.io/kusion/pkg/server/manager/stack" - appmiddleware "kusionstack.io/kusion/pkg/server/middleware" - authutil "kusionstack.io/kusion/pkg/server/util/auth" - logutil "kusionstack.io/kusion/pkg/server/util/logging" ) // @Id previewStack @@ -27,17 +19,20 @@ import ( // @Description Preview stack information by stack ID // @Tags stack // @Produce json -// @Param stack_id path int true "Stack ID" -// @Param output query string false "Output format. Choices are: json, default. Default to default output format in Kusion." -// @Param detail query bool false "Show detailed output" -// @Param specID query string false "The Spec ID to use for the preview. Default to the last one generated." -// @Param force query bool false "Force the preview even when the stack is locked" -// @Success 200 {object} models.Changes "Success" -// @Failure 400 {object} error "Bad Request" -// @Failure 401 {object} error "Unauthorized" -// @Failure 429 {object} error "Too Many Requests" -// @Failure 404 {object} error "Not Found" -// @Failure 500 {object} error "Internal Server Error" +// @Param stack_id path int true "Stack ID" +// @Param importedResources body request.StackImportRequest false "The resources to import during the stack preview" +// @Param workspace query string true "The target workspace to preview the spec in." +// @Param importResources query bool false "Import existing resources during the stack preview" +// @Param output query string false "Output format. Choices are: json, default. Default to default output format in Kusion." +// @Param detail query bool false "Show detailed output" +// @Param specID query string false "The Spec ID to use for the preview. Default to the last one generated." +// @Param force query bool false "Force the preview even when the stack is locked" +// @Success 200 {object} models.Changes "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" // @Router /api/v1/stacks/{stack_id}/preview [post] func (h *Handler) PreviewStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -84,6 +79,7 @@ func (h *Handler) PreviewStack() http.HandlerFunc { // @Tags stack // @Produce json // @Param stack_id path int true "Stack ID" +// @Param workspace query string true "The target workspace to preview the spec in." // @Param format query string false "The format to generate the spec in. Choices are: spec. Default to spec." // @Param force query bool false "Force the generate even when the stack is locked" // @Success 200 {object} v1.Spec "Success" @@ -120,16 +116,19 @@ func (h *Handler) GenerateStack() http.HandlerFunc { // @Description Apply stack information by stack ID // @Tags stack // @Produce json -// @Param stack_id path int true "Stack ID" -// @Param specID query string false "The Spec ID to use for the apply. Will generate a new spec if omitted." -// @Param force query bool false "Force the apply even when the stack is locked. May cause concurrency issues!!!" -// @Param dryrun query bool false "Apply in dry-run mode" -// @Success 200 {object} string "Success" -// @Failure 400 {object} error "Bad Request" -// @Failure 401 {object} error "Unauthorized" -// @Failure 429 {object} error "Too Many Requests" -// @Failure 404 {object} error "Not Found" -// @Failure 500 {object} error "Internal Server Error" +// @Param stack_id path int true "Stack ID" +// @Param importedResources body request.StackImportRequest false "The resources to import during the stack preview" +// @Param workspace query string true "The target workspace to preview the spec in." +// @Param importResources query bool false "Import existing resources during the stack preview" +// @Param specID query string false "The Spec ID to use for the apply. Will generate a new spec if omitted." +// @Param force query bool false "Force the apply even when the stack is locked. May cause concurrency issues!!!" +// @Param dryrun query bool false "Apply in dry-run mode" +// @Success 200 {object} string "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" // @Router /api/v1/stacks/{stack_id}/apply [post] func (h *Handler) ApplyStack() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -185,6 +184,7 @@ func (h *Handler) ApplyStack() http.HandlerFunc { // @Tags stack // @Produce json // @Param stack_id path int true "Stack ID" +// @Param workspace query string true "The target workspace to preview the spec in." // @Param force query bool false "Force the destroy even when the stack is locked. May cause concurrency issues!!!" // @Param dryrun query bool false "Destroy in dry-run mode" // @Success 200 {object} string "Success" @@ -220,49 +220,3 @@ func (h *Handler) DestroyStack() http.HandlerFunc { render.Render(w, r, handler.SuccessResponse(ctx, "destroy completed")) } } - -func requestHelper(r *http.Request) (context.Context, *httplog.Logger, *stackmanager.StackRequestParams, error) { - ctx := r.Context() - stackID := chi.URLParam(r, "stackID") - // Get stack with repository - id, err := strconv.Atoi(stackID) - if err != nil { - return nil, nil, nil, stackmanager.ErrInvalidStackID - } - logger := logutil.GetLogger(ctx) - // Get Params - outputParam := r.URL.Query().Get("output") - detailParam, _ := strconv.ParseBool(r.URL.Query().Get("detail")) - dryrunParam, _ := strconv.ParseBool(r.URL.Query().Get("dryrun")) - forceParam, _ := strconv.ParseBool(r.URL.Query().Get("force")) - importResourcesParam, _ := strconv.ParseBool(r.URL.Query().Get("importResources")) - specIDParam := r.URL.Query().Get("specID") - // TODO: Should match automatically eventually??? - workspaceParam := r.URL.Query().Get("workspace") - operatorParam, err := authutil.GetSubjectFromUnverifiedJWTToken(ctx, r) - // fall back to x-kusion-user if operator is not parsed from cookie - if operatorParam == "" || err != nil { - operatorParam = appmiddleware.GetUserID(ctx) - if operatorParam == "" { - operatorParam = constant.DefaultUser - } - } - if workspaceParam == "" { - workspaceParam = constant.DefaultWorkspace - } - executeParams := stackmanager.StackExecuteParams{ - Detail: detailParam, - Dryrun: dryrunParam, - Force: forceParam, - SpecID: specIDParam, - ImportResources: importResourcesParam, - } - params := stackmanager.StackRequestParams{ - StackID: uint(id), - Workspace: workspaceParam, - Format: outputParam, - Operator: operatorParam, - ExecuteParams: executeParams, - } - return ctx, logger, ¶ms, nil -} diff --git a/pkg/server/handler/stack/execute_async.go b/pkg/server/handler/stack/execute_async.go new file mode 100644 index 00000000..d07fce82 --- /dev/null +++ b/pkg/server/handler/stack/execute_async.go @@ -0,0 +1,422 @@ +package stack + +import ( + "fmt" + "io" + "net/http" + "time" + + "github.com/go-chi/render" + yamlv2 "gopkg.in/yaml.v2" + + apiv1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" + "kusionstack.io/kusion/pkg/domain/constant" + "kusionstack.io/kusion/pkg/domain/request" + "kusionstack.io/kusion/pkg/engine/operation/models" + "kusionstack.io/kusion/pkg/server/handler" + stackmanager "kusionstack.io/kusion/pkg/server/manager/stack" + + logutil "kusionstack.io/kusion/pkg/server/util/logging" +) + +// @Id previewStackAsync +// @Summary Asynchronously preview stack +// @Description Start a run and asynchronously preview stack changes by stack ID +// @Tags stack +// @Produce json +// @Param stack_id path int true "Stack ID" +// @Param importedResources body request.StackImportRequest false "The resources to import during the stack preview" +// @Param workspace query string true "The target workspace to preview the spec in." +// @Param importResources query bool false "Import existing resources during the stack preview" +// @Param output query string false "Output format. Choices are: json, default. Default to default output format in Kusion." +// @Param detail query bool false "Show detailed output" +// @Param specID query string false "The Spec ID to use for the preview. Default to the last one generated." +// @Param force query bool false "Force the preview even when the stack is locked" +// @Success 200 {object} entity.Run "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" +// @Router /api/v1/stacks/{stack_id}/preview [post] +func (h *Handler) PreviewStackAsync() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Getting stuff from context + ctx, logger, params, err := requestHelper(r) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + logger.Info("Previewing stack asynchronously...", "stackID", params.StackID) + + var requestPayload request.CreateRunRequest + if err := requestPayload.Decode(r); err != nil { + if err == io.EOF { + render.Render(w, r, handler.FailureResponse(ctx, fmt.Errorf("request body should not be empty when importResources is set to true"))) + return + } else { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + } + + // Create a Run object in database and start background task + runEntity, err := h.stackManager.CreateRun(ctx, requestPayload) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + + runLogger := logutil.GetRunLogger(ctx) + runLogger.Info("Starting previewing stack in StackManager ... This is a preview run.", "runID", runEntity.ID) + + // Starts a safe goroutine using given recover handler + inBufferZone := h.workerPool.Do(func() { + // defer safe.HandleCrash(aciLoggingRecoverHandler(h.aciClient, &req, log)) + logger.Info("Async preview in progress") + var previewChanges any + newCtx, cancel := CopyToNewContextWithTimeout(ctx, constant.RunTimeOut) + defer cancel() // make sure the context is canceled to free resources + + // update status of the run when exiting the async run + defer func() { + select { + case <-newCtx.Done(): + logger.Info("preview execution timed out", "stackID", params.StackID, "time", time.Now(), "timeout", newCtx.Err()) + h.setRunToCancelled(newCtx, runEntity.ID) + default: + if err != nil { + logger.Info("preview failed for stack", "stackID", params.StackID, "time", time.Now()) + h.setRunToFailed(newCtx, runEntity.ID) + } else { + logger.Info("preview completed for stack", "stackID", params.StackID, "time", time.Now()) + if pc, ok := previewChanges.(*models.Changes); ok { + h.setRunToSuccess(newCtx, runEntity.ID, pc) + } else { + logger.Error("Error casting preview changes to models.Changes", "error", "casting error") + h.setRunToFailed(newCtx, runEntity.ID) + } + } + } + }() + + // Call preview stack + changes, err := h.stackManager.PreviewStack(newCtx, params, requestPayload.ImportedResources) + if err != nil { + // render.Render(w, r, handler.FailureResponse(ctx, err)) + logger.Error("Error previewing stack", "error", err) + return + } + + previewChanges, err = stackmanager.ProcessChanges(newCtx, w, changes, params.Format, params.ExecuteParams.Detail) + if err != nil { + // render.Render(w, r, handler.FailureResponse(ctx, err)) + logger.Error("Error processing preview changes", "error", err) + return + } + }) + defer func() { + if inBufferZone { + logger.Info("The task is in the buffer zone, waiting for an available worker") + h.setRunToQueued(ctx, runEntity.ID) + } + }() + render.Render(w, r, handler.SuccessResponse(ctx, runEntity)) + } +} + +// @Id applyStackAsync +// @Summary Asynchronously apply stack +// @Description Start a run and asynchronously apply stack changes by stack ID +// @Tags stack +// @Produce json +// @Param stack_id path int true "Stack ID" +// @Param importedResources body request.StackImportRequest false "The resources to import during the stack preview" +// @Param workspace query string true "The target workspace to preview the spec in." +// @Param importResources query bool false "Import existing resources during the stack preview" +// @Param specID query string false "The Spec ID to use for the apply. Will generate a new spec if omitted." +// @Param force query bool false "Force the apply even when the stack is locked. May cause concurrency issues!!!" +// @Param dryrun query bool false "Apply in dry-run mode" +// @Success 200 {object} entity.Run "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" +// @Router /api/v1/stacks/{stack_id}/apply/async [post] +func (h *Handler) ApplyStackAsync() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Getting stuff from context + ctx, logger, params, err := requestHelper(r) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + logger.Info("Applying stack asynchronously...", "stackID", params.StackID) + + var requestPayload request.CreateRunRequest + if err := requestPayload.Decode(r); err != nil { + if err == io.EOF { + render.Render(w, r, handler.FailureResponse(ctx, fmt.Errorf("request body should not be empty when importResources is set to true"))) + return + } else { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + } + + // Create a Run object in database and start background task + runEntity, err := h.stackManager.CreateRun(ctx, requestPayload) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + + runLogger := logutil.GetRunLogger(ctx) + runLogger.Info("Starting applying stack in StackManager ... This is an apply run.", "runID", runEntity.ID) + + // Starts a safe goroutine using given recover handler + inBufferZone := h.workerPool.Do(func() { + // defer safe.HandleCrash(aciLoggingRecoverHandler(h.aciClient, &req, log)) + logger.Info("Async apply in progress") + newCtx, cancel := CopyToNewContextWithTimeout(ctx, constant.RunTimeOut) + defer cancel() // make sure the context is canceled to free resources + + // update status of the run when exiting the async run + defer func() { + select { + case <-newCtx.Done(): + logger.Info("apply execution timed out", "stackID", params.StackID, "time", time.Now(), "timeout", newCtx.Err()) + h.setRunToCancelled(newCtx, runEntity.ID) + default: + if err != nil { + logger.Info("apply failed for stack", "stackID", params.StackID, "time", time.Now()) + h.setRunToFailed(newCtx, runEntity.ID) + } else { + logger.Info("apply completed for stack", "stackID", params.StackID, "time", time.Now()) + h.setRunToSuccess(newCtx, runEntity.ID, "apply completed") + } + } + }() + + // call apply stack + err = h.stackManager.ApplyStack(newCtx, params, requestPayload.ImportedResources) + if err != nil { + if err == stackmanager.ErrDryrunDestroy { + render.Render(w, r, handler.SuccessResponse(ctx, "Dry-run mode enabled, the above resources will be applied if dryrun is set to false")) + return + } else { + // render.Render(w, r, handler.FailureResponse(ctx, err)) + logger.Error("Error applying stack", "error", err) + return + } + } + }) + + defer func() { + if inBufferZone { + logger.Info("The task is in the buffer zone, waiting for an available worker") + h.setRunToQueued(ctx, runEntity.ID) + } + }() + render.Render(w, r, handler.SuccessResponse(ctx, runEntity)) + // TODO: How to implement watch? + // if o.Watch { + // fmt.Println("Start watching changes ...") + // if err = Watch(o, sp, changes); err != nil { + // return err + // } + // } + } +} + +// @Id generateStackAsync +// @Summary Asynchronously generate stack +// @Description Start a run and asynchronously generate stack spec by stack ID +// @Tags stack +// @Produce json +// @Param stack_id path int true "Stack ID" +// @Param workspace query string true "The target workspace to preview the spec in." +// @Param format query string false "The format to generate the spec in. Choices are: spec. Default to spec." +// @Param force query bool false "Force the generate even when the stack is locked" +// @Success 200 {object} v1.Spec "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" +// @Router /api/v1/stacks/{stack_id}/generate/async [post] +func (h *Handler) GenerateStackAsync() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Getting stuff from context + ctx, logger, params, err := requestHelper(r) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + logger.Info("Generating stack asynchronously...", "stackID", params.StackID) + + var requestPayload request.CreateRunRequest + if err := requestPayload.Decode(r); err != nil { + if err == io.EOF { + render.Render(w, r, handler.FailureResponse(ctx, fmt.Errorf("request body should not be empty when importResources is set to true"))) + return + } else { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + } + + // Create a Run object in database and start background task + runEntity, err := h.stackManager.CreateRun(ctx, requestPayload) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + + runLogger := logutil.GetRunLogger(ctx) + runLogger.Info("Starting generating stack in StackManager ... This is a generate run.", "runID", runEntity.ID) + + // Starts a safe goroutine using given recover handler + inBufferZone := h.workerPool.Do(func() { + // defer safe.HandleCrash(aciLoggingRecoverHandler(h.aciClient, &req, log)) + logger.Info("Async generate in progress") + newCtx, cancel := CopyToNewContextWithTimeout(ctx, constant.RunTimeOut) + var sp *apiv1.Spec + defer cancel() // make sure the context is canceled to free resources + + // update status of the run when exiting the async run + defer func() { + select { + case <-newCtx.Done(): + logger.Info("generate execution timed out", "stackID", params.StackID, "time", time.Now(), "timeout", newCtx.Err()) + h.setRunToCancelled(newCtx, runEntity.ID) + default: + if err != nil { + logger.Info("generate failed for stack", "stackID", params.StackID, "time", time.Now()) + h.setRunToFailed(newCtx, runEntity.ID) + } else { + logger.Info("generate completed for stack", "stackID", params.StackID, "time", time.Now()) + if yaml, err := yamlv2.Marshal(sp); err == nil { + h.setRunToSuccess(newCtx, runEntity.ID, string(yaml)) + } else { + logger.Error("Error marshalling generated spec", "error", err) + h.setRunToFailed(newCtx, runEntity.ID) + } + } + } + }() + + // Call generate stack + _, sp, err := h.stackManager.GenerateSpec(newCtx, params) + if err != nil { + // render.Render(w, r, handler.FailureResponse(ctx, err)) + logger.Error("Error generating stack", "error", err) + return + } + }) + + defer func() { + if inBufferZone { + logger.Info("The task is in the buffer zone, waiting for an available worker") + h.setRunToQueued(ctx, runEntity.ID) + } + }() + render.Render(w, r, handler.SuccessResponse(ctx, runEntity)) + } +} + +// @Id destroyStackAsync +// @Summary Asynchronously destroy stack +// @Description Start a run and asynchronously destroy stack resources by stack ID +// @Tags stack +// @Produce json +// @Param stack_id path int true "Stack ID" +// @Param workspace query string true "The target workspace to preview the spec in." +// @Param force query bool false "Force the destroy even when the stack is locked. May cause concurrency issues!!!" +// @Param dryrun query bool false "Destroy in dry-run mode" +// @Success 200 {object} string "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" +// @Router /api/v1/stacks/{stack_id}/destroy/async [post] +func (h *Handler) DestroyStackAsync() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Getting stuff from context + ctx, logger, params, err := requestHelper(r) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + logger.Info("Destroying stack asynchronously...", "stackID", params.StackID) + + var requestPayload request.CreateRunRequest + if err := requestPayload.Decode(r); err != nil { + if err == io.EOF { + render.Render(w, r, handler.FailureResponse(ctx, fmt.Errorf("request body should not be empty when importResources is set to true"))) + return + } else { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + } + + // Create a Run object in database and start background task + runEntity, err := h.stackManager.CreateRun(ctx, requestPayload) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + + runLogger := logutil.GetRunLogger(ctx) + runLogger.Info("Starting destroying stack in StackManager ... This is a destroy run.", "runID", runEntity.ID) + + // Starts a safe goroutine using given recover handler + inBufferZone := h.workerPool.Do(func() { + // defer safe.HandleCrash(aciLoggingRecoverHandler(h.aciClient, &req, log)) + logger.Info("Async destroy in progress") + newCtx, cancel := CopyToNewContextWithTimeout(ctx, constant.RunTimeOut) + defer cancel() // make sure the context is canceled to free resources + + // update status of the run when exiting the async run + defer func() { + select { + case <-newCtx.Done(): + logger.Info("destroy execution timed out", "stackID", params.StackID, "time", time.Now(), "timeout", newCtx.Err()) + h.setRunToCancelled(newCtx, runEntity.ID) + default: + if err != nil { + logger.Info("destroy failed for stack", "stackID", params.StackID, "time", time.Now()) + h.setRunToFailed(newCtx, runEntity.ID) + } else { + logger.Info("destroy completed for stack", "stackID", params.StackID, "time", time.Now()) + h.setRunToSuccess(newCtx, runEntity.ID, "destroy completed") + } + } + }() + + err = h.stackManager.DestroyStack(newCtx, params, w) + if err != nil { + if err == stackmanager.ErrDryrunDestroy { + // render.Render(w, r, handler.SuccessResponse(ctx, "Dry-run mode enabled, the above resources will be destroyed if dryrun is set to false")) + logger.Info("Dry-run mode enabled, the above resources will be destroyed if dryrun is set to false") + return + } else { + // render.Render(w, r, handler.FailureResponse(ctx, err)) + logger.Error("Error destroying stack", "error", err) + return + } + } + }) + + defer func() { + if inBufferZone { + logger.Info("The task is in the buffer zone, waiting for an available worker") + h.setRunToQueued(ctx, runEntity.ID) + } + }() + render.Render(w, r, handler.SuccessResponse(ctx, runEntity)) + } +} diff --git a/pkg/server/handler/stack/handler.go b/pkg/server/handler/stack/handler.go index 5b820921..6d909075 100644 --- a/pkg/server/handler/stack/handler.go +++ b/pkg/server/handler/stack/handler.go @@ -169,7 +169,6 @@ func (h *Handler) GetStack() http.HandlerFunc { // @Param projectName query string false "ProjectName to filter stacks by. Default to all" // @Param cloud query string false "Cloud to filter stacks by. Default to all" // @Param env query string false "Environment to filter stacks by. Default to all" -// @Param getLastSyncedBase query bool false "Whether to get last synced base revision. Default to false" // @Success 200 {object} []entity.Stack "Success" // @Failure 400 {object} error "Bad Request" // @Failure 401 {object} error "Unauthorized" diff --git a/pkg/server/handler/stack/handler_test.go b/pkg/server/handler/stack/handler_test.go index b7220b25..4f1aae4f 100644 --- a/pkg/server/handler/stack/handler_test.go +++ b/pkg/server/handler/stack/handler_test.go @@ -122,8 +122,8 @@ func TestStackHandler(t *testing.T) { req.Header.Add("Content-Type", "application/json") sqlMock.ExpectQuery("SELECT"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "Organization__id", "Organization__name", "Organization__owners", "Source__id", "Source__remote", "Source__source_provider"}). - AddRow(1, projectName, projectPath, 1, "test-org", owners, 1, "https://github.com/test/repo", constant.SourceProviderTypeGithub)) + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "path", "Organization__id", "Organization__name", "Organization__owners", "Source__id", "Source__name", "Source__remote", "Source__source_provider"}). + AddRow(1, projectName, projectPath, 1, "test-org", owners, 1, "test-source", "https://github.com/test/repo", constant.SourceProviderTypeGithub)) sqlMock.ExpectBegin() sqlMock.ExpectExec("INSERT"). WillReturnResult(sqlmock.NewResult(int64(1), int64(1))) @@ -326,8 +326,9 @@ func setupTest(t *testing.T) (sqlmock.Sqlmock, *gorm.DB, *httptest.ResponseRecor projectRepo := persistence.NewProjectRepository(fakeGDB) workspaceRepo := persistence.NewWorkspaceRepository(fakeGDB) resourceRepo := persistence.NewResourceRepository(fakeGDB) + runRepo := persistence.NewRunRepository(fakeGDB) stackHandler := &Handler{ - stackManager: stackmanager.NewStackManager(stackRepo, projectRepo, workspaceRepo, resourceRepo, entity.Backend{}, constant.MaxConcurrent), + stackManager: stackmanager.NewStackManager(stackRepo, projectRepo, workspaceRepo, resourceRepo, runRepo, entity.Backend{}, constant.MaxConcurrent), } recorder := httptest.NewRecorder() return sqlMock, fakeGDB, recorder, stackHandler diff --git a/pkg/server/handler/stack/run.go b/pkg/server/handler/stack/run.go new file mode 100644 index 00000000..ed7f61e6 --- /dev/null +++ b/pkg/server/handler/stack/run.go @@ -0,0 +1,133 @@ +package stack + +import ( + "encoding/json" + "net/http" + + "github.com/go-chi/render" + "kusionstack.io/kusion/pkg/server/handler" + logutil "kusionstack.io/kusion/pkg/server/util/logging" +) + +// @Id getRun +// @Summary Get run +// @Description Get run information by run ID +// @Tags run +// @Produce json +// @Param run path int true "Run ID" +// @Success 200 {object} entity.Run "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" +// @Router /api/v1/runs/{run_id} [get] +func (h *Handler) GetRun() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Getting stuff from context + ctx, logger, params, err := runRequestHelper(r) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + logger.Info("Getting run...", "runID", params.RunID) + + existingEntity, err := h.stackManager.GetRunByID(ctx, params.RunID) + handler.HandleResult(w, r, ctx, err, existingEntity) + } +} + +// @Id getRunResult +// @Summary Get run result +// @Description Get run result by run ID +// @Tags run +// @Produce json +// @Param run path int true "Run ID" +// @Success 200 {object} entity.Run "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" +// @Router /api/v1/runs/{run_id}/result [get] +func (h *Handler) GetRunResult() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Getting stuff from context + ctx, logger, params, err := runRequestHelper(r) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + logger.Info("Getting run...", "runID", params.RunID) + + existingEntity, err := h.stackManager.GetRunByID(ctx, params.RunID) + if err != nil { + handler.HandleResult(w, r, ctx, err, existingEntity) + return + } + var resultJSON any + err = json.Unmarshal([]byte(existingEntity.Result), &resultJSON) + handler.HandleResult(w, r, ctx, err, resultJSON) + } +} + +// @Id listRun +// @Summary List runs +// @Description List all runs +// @Tags stack +// @Produce json +// @Param projectID query uint false "ProjectID to filter runs by. Default to all" +// @Param orgID query uint false "OrgID to filter runs by. Default to all" +// @Param projectName query string false "ProjectName to filter runs by. Default to all" +// @Param cloud query string false "Cloud to filter runs by. Default to all" +// @Param env query string false "Environment to filter runs by. Default to all" +// @Success 200 {object} []entity.Stack "Success" +// @Failure 400 {object} error "Bad Request" +// @Failure 401 {object} error "Unauthorized" +// @Failure 429 {object} error "Too Many Requests" +// @Failure 404 {object} error "Not Found" +// @Failure 500 {object} error "Internal Server Error" +// @Router /api/v1/runs [get] +func (h *Handler) ListRuns() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Getting stuff from context + ctx := r.Context() + logger := logutil.GetLogger(ctx) + logger.Info("Listing runs...") + + projectIDParam := r.URL.Query().Get("projectID") + stackIDParam := r.URL.Query().Get("stackID") + workspaceParam := r.URL.Query().Get("workspace") + + filter, err := h.stackManager.BuildRunFilter(ctx, projectIDParam, stackIDParam, workspaceParam) + if err != nil { + render.Render(w, r, handler.FailureResponse(ctx, err)) + return + } + + runEntities, err := h.stackManager.ListRuns(ctx, filter) + handler.HandleResult(w, r, ctx, err, runEntities) + } +} + +// TODO: StreamRunLogs to stream logs from a run using SSE +// func StreamRunLogs(c echo.Context) error { +// // 设置 SSE headers +// c.Response().Header().Set("Content-Type", "text/event-stream") +// c.Response().Header().Set("Cache-Control", "no-cache") +// c.Response().Header().Set("Connection", "keep-alive") + +// // id := c.Param("id") +// logs := []string{"log1", "log2", "log3"} + +// for { +// if len(logs) > 0 { +// for _, logMessage := range logs { +// fmt.Fprintf(c.Response().Writer, "data: %s\n\n", logMessage) +// } +// logs = nil +// c.Response().Flush() +// } +// time.Sleep(1 * time.Second) +// } +// } diff --git a/pkg/server/handler/stack/types.go b/pkg/server/handler/stack/types.go index 5f7c41f5..cdfa7e2f 100644 --- a/pkg/server/handler/stack/types.go +++ b/pkg/server/handler/stack/types.go @@ -1,25 +1,28 @@ package stack import ( + worker "kusionstack.io/kusion/pkg/infra/util/worker" stackmanager "kusionstack.io/kusion/pkg/server/manager/stack" ) func NewHandler( stackManager *stackmanager.StackManager, + maxAsyncConcurrent int, + maxAsyncBuffer int, ) (*Handler, error) { return &Handler{ stackManager: stackManager, + workerPool: worker.NewWorkerPool(maxAsyncConcurrent, maxAsyncBuffer), }, nil } type Handler struct { stackManager *stackmanager.StackManager + workerPool *worker.WorkerPool } -type StackRequestParams struct { - StackID uint - Workspace string - Format string - Detail bool - Dryrun bool +// TODO: graceful shutdown of worker pool when exiting +// Capture sigterm and sigint signals to shutdown the worker pool +func (h *Handler) Shutdown() { + h.workerPool.Wait() // wait for all workers to finish } diff --git a/pkg/server/handler/stack/utils.go b/pkg/server/handler/stack/utils.go new file mode 100644 index 00000000..8c9471c8 --- /dev/null +++ b/pkg/server/handler/stack/utils.go @@ -0,0 +1,171 @@ +package stack + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "strconv" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/httplog/v2" + "kusionstack.io/kusion/pkg/domain/constant" + "kusionstack.io/kusion/pkg/domain/request" + stackmanager "kusionstack.io/kusion/pkg/server/manager/stack" + appmiddleware "kusionstack.io/kusion/pkg/server/middleware" + + authutil "kusionstack.io/kusion/pkg/server/util/auth" + logutil "kusionstack.io/kusion/pkg/server/util/logging" +) + +func (h *Handler) setRunToSuccess(ctx context.Context, runID uint, result any) { + logger := logutil.GetLogger(ctx) + runLogs := logutil.GetRunLoggerBuffer(ctx) + resultBytes, err := json.Marshal(result) + if err != nil { + logger.Error("Error marshalling preview changes", "error", err) + return + } + logger.Info("Result", "result", string(resultBytes)) + // Update the Run object in database to include the preview result + updateRunResultPayload := request.UpdateRunResultRequest{ + Result: string(resultBytes), + Status: string(constant.RunStatusSucceeded), + Logs: runLogs.String(), + } + _, err = h.stackManager.UpdateRunResultAndStatusByID(ctx, runID, updateRunResultPayload) + if err != nil { + logger.Error("Error updating run result after success", "error", err) + return + } +} + +func (h *Handler) setRunToFailed(ctx context.Context, runID uint) { + logger := logutil.GetLogger(ctx) + runLogs := logutil.GetRunLoggerBuffer(ctx) + updateRunResultPayload := request.UpdateRunResultRequest{ + Result: "", + Status: string(constant.RunStatusFailed), + Logs: runLogs.String(), + } + _, err := h.stackManager.UpdateRunResultAndStatusByID(ctx, runID, updateRunResultPayload) + if err != nil { + logger.Error("Error updating run result after failure", "error", err) + } +} + +func (h *Handler) setRunToCancelled(ctx context.Context, runID uint) { + logger := logutil.GetLogger(ctx) + runLogs := logutil.GetRunLoggerBuffer(ctx) + updateRunResultPayload := request.UpdateRunResultRequest{ + Result: "", + Status: string(constant.RunStatusCancelled), + Logs: runLogs.String(), + } + newCtx := CopyToNewContext(ctx) + _, err := h.stackManager.UpdateRunResultAndStatusByID(newCtx, runID, updateRunResultPayload) + if err != nil { + logger.Error("Error updating run result after timeout", "error", err) + } +} + +func (h *Handler) setRunToQueued(ctx context.Context, runID uint) { + logger := logutil.GetLogger(ctx) + runLogs := logutil.GetRunLoggerBuffer(ctx) + updateRunResultPayload := request.UpdateRunResultRequest{ + Result: "", + Status: string(constant.RunStatusQueued), + Logs: runLogs.String(), + } + newCtx := CopyToNewContext(ctx) + _, err := h.stackManager.UpdateRunResultAndStatusByID(newCtx, runID, updateRunResultPayload) + if err != nil { + logger.Error("Error updating run result after queueing", "error", err) + } +} + +func requestHelper(r *http.Request) (context.Context, *httplog.Logger, *stackmanager.StackRequestParams, error) { + ctx := r.Context() + stackID := chi.URLParam(r, "stackID") + // Get stack with repository + id, err := strconv.Atoi(stackID) + if err != nil { + return nil, nil, nil, stackmanager.ErrInvalidStackID + } + logger := logutil.GetLogger(ctx) + // Get Params + outputParam := r.URL.Query().Get("output") + detailParam, _ := strconv.ParseBool(r.URL.Query().Get("detail")) + dryrunParam, _ := strconv.ParseBool(r.URL.Query().Get("dryrun")) + forceParam, _ := strconv.ParseBool(r.URL.Query().Get("force")) + noCacheParam, _ := strconv.ParseBool(r.URL.Query().Get("noCache")) + importResourcesParam, _ := strconv.ParseBool(r.URL.Query().Get("importResources")) + specIDParam := r.URL.Query().Get("specID") + // TODO: Should match automatically eventually??? + workspaceParam := r.URL.Query().Get("workspace") + operatorParam, err := authutil.GetSubjectFromUnverifiedJWTToken(ctx, r) + // fall back to x-kusion-user if operator is not parsed from cookie + if operatorParam == "" || err != nil { + operatorParam = appmiddleware.GetUserID(ctx) + if operatorParam == "" { + operatorParam = constant.DefaultUser + } + } + if workspaceParam == "" { + workspaceParam = constant.DefaultWorkspace + } + executeParams := stackmanager.StackExecuteParams{ + Detail: detailParam, + Dryrun: dryrunParam, + Force: forceParam, + SpecID: specIDParam, + ImportResources: importResourcesParam, + NoCache: noCacheParam, + } + params := stackmanager.StackRequestParams{ + StackID: uint(id), + Workspace: workspaceParam, + Format: outputParam, + Operator: operatorParam, + ExecuteParams: executeParams, + } + return ctx, logger, ¶ms, nil +} + +func runRequestHelper(r *http.Request) (context.Context, *httplog.Logger, *stackmanager.RunRequestParams, error) { + ctx := r.Context() + runID := chi.URLParam(r, "runID") + // Get stack with repository + id, err := strconv.Atoi(runID) + if err != nil { + return nil, nil, nil, stackmanager.ErrInvalidRunID + } + logger := logutil.GetLogger(ctx) + params := stackmanager.RunRequestParams{ + RunID: uint(id), + } + return ctx, logger, ¶ms, nil +} + +func CopyToNewContext(ctx context.Context) context.Context { + newCtx := context.Background() + newCtx = context.WithValue(newCtx, appmiddleware.TraceIDKey, appmiddleware.GetTraceID(ctx)) + newCtx = context.WithValue(newCtx, appmiddleware.UserIDKey, appmiddleware.GetUserID(ctx)) + if logger, ok := ctx.Value(appmiddleware.APILoggerKey).(*httplog.Logger); ok { + newCtx = context.WithValue(newCtx, appmiddleware.APILoggerKey, logger) + } + if runLogger, ok := ctx.Value(appmiddleware.RunLoggerKey).(*httplog.Logger); ok { + newCtx = context.WithValue(newCtx, appmiddleware.RunLoggerKey, runLogger) + } + if runLoggerBuffer, ok := ctx.Value(appmiddleware.RunLoggerBufferKey).(*bytes.Buffer); ok { + newCtx = context.WithValue(newCtx, appmiddleware.RunLoggerBufferKey, runLoggerBuffer) + } + return newCtx +} + +func CopyToNewContextWithTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + newCtx := CopyToNewContext(ctx) + newCtxWithTimeout, cancel := context.WithTimeout(newCtx, timeout) + return newCtxWithTimeout, cancel +} diff --git a/pkg/server/manager/stack/execute.go b/pkg/server/manager/stack/execute.go index 06741577..113df23a 100644 --- a/pkg/server/manager/stack/execute.go +++ b/pkg/server/manager/stack/execute.go @@ -14,11 +14,12 @@ import ( "kusionstack.io/kusion/pkg/domain/constant" "kusionstack.io/kusion/pkg/domain/request" "kusionstack.io/kusion/pkg/engine/release" + "kusionstack.io/kusion/pkg/engine/resource/graph" engineapi "kusionstack.io/kusion/pkg/engine/api" + sourceapi "kusionstack.io/kusion/pkg/engine/api/source" "kusionstack.io/kusion/pkg/engine/operation/models" - sourceapi "kusionstack.io/kusion/pkg/engine/api/source" appmiddleware "kusionstack.io/kusion/pkg/server/middleware" logutil "kusionstack.io/kusion/pkg/server/util/logging" ) @@ -72,29 +73,24 @@ func (m *StackManager) GenerateSpec(ctx context.Context, params *StackRequestPar return "", nil, err } - // Build API inputs - // get project to get source and workdir - projectEntity, err := m.projectRepo.Get(ctx, stackEntity.Project.ID) + directory, workDir, err := m.GetWorkdirAndDirectory(ctx, params, stackEntity) if err != nil { return "", nil, err } - - directory, workDir, err := GetWorkDirFromSource(ctx, stackEntity, projectEntity) - logger.Info("workDir derived", "workDir", workDir) - logger.Info("directory derived", "directory", directory) - stack.Path = workDir - if err != nil { - return "", nil, err - } + + // Cleanup + defer func() { + if params.ExecuteParams.NoCache { + sourceapi.Cleanup(ctx, directory) + } + }() stackEntity.SyncState = constant.StackStateGenerated err = m.stackRepo.Update(ctx, stackEntity) if err != nil { return "", nil, err } - // Cleanup - defer sourceapi.Cleanup(ctx, directory) // Generate spec sp, err := engineapi.GenerateSpecWithSpinner(project, stack, ws, true) @@ -116,6 +112,7 @@ func (m *StackManager) PreviewStack(ctx context.Context, params *StackRequestPar defer func() { if err != nil { + logger.Info("Error occurred during previewing stack. Setting stack sync state to preview failed") stackEntity.SyncState = constant.StackStatePreviewFailed m.stackRepo.Update(ctx, stackEntity) } else { @@ -149,13 +146,6 @@ func (m *StackManager) PreviewStack(ctx context.Context, params *StackRequestPar if err != nil { return nil, err } - releasePath := getReleasePath(stackEntity.Path, "default") - releaseStorage, err := stateBackend.StateStorageWithPath(releasePath) - if err != nil { - return nil, err - } - logger.Info("State storage found with path", "releasePath", releasePath) - // Get workspace configurations from backend wsStorage, err := stateBackend.WorkspaceStorage() if err != nil { @@ -165,15 +155,26 @@ func (m *StackManager) PreviewStack(ctx context.Context, params *StackRequestPar if err != nil { return nil, err } - // Checkout workdir - directory, workDir, err := GetWorkDirFromSource(ctx, stackEntity, stackEntity.Project) + + releasePath := getReleasePath(constant.DefaultReleaseNamespace, stackEntity.Project.Source.Name, stackEntity.Project.Path, ws.Name) + releaseStorage, err := stateBackend.StateStorageWithPath(releasePath) + if err != nil { + return nil, err + } + logger.Info("State storage found with path", "releasePath", releasePath) + + directory, workDir, err := m.GetWorkdirAndDirectory(ctx, params, stackEntity) if err != nil { return nil, err } stack.Path = workDir // Cleanup - defer sourceapi.Cleanup(ctx, directory) + defer func() { + if params.ExecuteParams.NoCache { + sourceapi.Cleanup(ctx, directory) + } + }() // Generate spec using default generator sp, err = engineapi.GenerateSpecWithSpinner(project, stack, ws, true) @@ -288,7 +289,7 @@ func (m *StackManager) ApplyStack(ctx context.Context, params *StackRequestParam } // create release - releasePath := getReleasePath(stackEntity.Path, "default") + releasePath := getReleasePath(constant.DefaultReleaseNamespace, stackEntity.Project.Source.Name, stackEntity.Project.Path, ws.Name) storage, err = stackBackend.StateStorageWithPath(releasePath) if err != nil { return err @@ -322,15 +323,19 @@ func (m *StackManager) ApplyStack(ctx context.Context, params *StackRequestParam executeOptions := BuildOptions(params.ExecuteParams.Dryrun, m.maxConcurrent) logger.Info("Previewing using the default generator ...") - // Checkout workdir - directory, workDir, err := GetWorkDirFromSource(ctx, stackEntity, stackEntity.Project) + + directory, workDir, err := m.GetWorkdirAndDirectory(ctx, params, stackEntity) if err != nil { return err } stack.Path = workDir // Cleanup - defer sourceapi.Cleanup(ctx, directory) + defer func() { + if params.ExecuteParams.NoCache { + sourceapi.Cleanup(ctx, directory) + } + }() // Generate spec using default generator sp, err = engineapi.GenerateSpecWithSpinner(project, stack, ws, true) @@ -388,8 +393,40 @@ func (m *StackManager) ApplyStack(ctx context.Context, params *StackRequestParam } executeOptions = BuildOptions(params.ExecuteParams.Dryrun, m.maxConcurrent) + + // Get graph storage directory, create if not exist + graphStorage, err := stackBackend.GraphStorage(project.Name, ws.Name) + if err != nil { + return err + } + + // Try to get existing graph, use the graph if exists + var gph *apiv1.Graph + if graphStorage.CheckGraphStorageExistence() { + gph, err = graphStorage.Get() + if err != nil { + return err + } + err = graph.ValidateGraph(gph) + if err != nil { + return err + } + // Put new resources from the generated spec to graph + gph, err = graph.GenerateGraph(sp.Resources, gph) + } else { + // Create a new graph to be used globally if no graph is stored in the storage + gph = &apiv1.Graph{ + Project: project.Name, + Workspace: ws.Name, + } + gph, err = graph.GenerateGraph(sp.Resources, gph) + } + if err != nil { + return err + } + var upRel *apiv1.Release - if upRel, err = engineapi.Apply(ctx, executeOptions, storage, rel, changes, os.Stdout); err != nil { + if upRel, err = engineapi.Apply(ctx, executeOptions, storage, rel, gph, changes, os.Stdout); err != nil { return err } rel = upRel @@ -461,7 +498,7 @@ func (m *StackManager) DestroyStack(ctx context.Context, params *StackRequestPar if err != nil { return err } - releasePath := getReleasePath(stackEntity.Path, "default") + releasePath := getReleasePath(constant.DefaultReleaseNamespace, stackEntity.Project.Source.Name, stackEntity.Project.Path, ws.Name) storage, err = stackBackend.StateStorageWithPath(releasePath) if err != nil { return err diff --git a/pkg/server/manager/stack/run.go b/pkg/server/manager/stack/run.go new file mode 100644 index 00000000..009c0331 --- /dev/null +++ b/pkg/server/manager/stack/run.go @@ -0,0 +1,137 @@ +package stack + +import ( + "context" + "errors" + "fmt" + + "github.com/jinzhu/copier" + "gorm.io/gorm" + + "kusionstack.io/kusion/pkg/domain/constant" + "kusionstack.io/kusion/pkg/domain/entity" + "kusionstack.io/kusion/pkg/domain/request" + + logutil "kusionstack.io/kusion/pkg/server/util/logging" +) + +func (m *StackManager) ListRuns(ctx context.Context, filter *entity.RunFilter) ([]*entity.Run, error) { + runEntities, err := m.runRepo.List(ctx, filter) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingStack + } + return nil, err + } + return runEntities, nil +} + +func (m *StackManager) GetRunByID(ctx context.Context, id uint) (*entity.Run, error) { + existingEntity, err := m.runRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGettingNonExistingStack + } + return nil, err + } + return existingEntity, nil +} + +func (m *StackManager) DeleteRunByID(ctx context.Context, id uint) error { + err := m.runRepo.Delete(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGettingNonExistingStack + } + return err + } + return nil +} + +func (m *StackManager) UpdateRunByID(ctx context.Context, id uint, requestPayload request.UpdateRunRequest) (*entity.Run, error) { + // Convert request payload to domain model + var requestEntity entity.Run + if err := copier.Copy(&requestEntity, &requestPayload); err != nil { + return nil, err + } + + // Get the existing stack by id + updatedEntity, err := m.runRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUpdatingNonExistingStack + } + return nil, err + } + + // Overwrite non-zero values in request entity to existing entity + copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) + + // Update stack with repository + err = m.runRepo.Update(ctx, updatedEntity) + if err != nil { + return nil, err + } + return updatedEntity, nil +} + +func (m *StackManager) CreateRun(ctx context.Context, requestPayload request.CreateRunRequest) (*entity.Run, error) { + logger := logutil.GetLogger(ctx) + // Convert request payload to domain model + var createdEntity entity.Run + err := copier.Copy(&createdEntity, &requestPayload) + if err != nil { + return nil, err + } + + var stackEntity *entity.Stack + if requestPayload.StackID != 0 { + // If stack id is provided, get stack by id + logger.Info("Stack ID provided, getting stack by ID...", "stackID", requestPayload.StackID) + stackEntity, err = m.stackRepo.Get(ctx, requestPayload.StackID) + if err != nil { + return nil, err + } + createdEntity.Stack = stackEntity + } + + logger.Info("Creating new run for stack and workspace", "stack", fmt.Sprint(createdEntity.Stack.ID), "workspace", createdEntity.Workspace) + + // The default status is InProgress + createdEntity.Status = constant.RunStatusInProgress + // Create run with repository + err = m.runRepo.Create(ctx, &createdEntity) + if err != nil && err == gorm.ErrDuplicatedKey { + return nil, constant.ErrStackAlreadyExists + } else if err != nil { + return nil, err + } + return &createdEntity, nil +} + +func (m *StackManager) UpdateRunResultAndStatusByID(ctx context.Context, id uint, requestPayload request.UpdateRunResultRequest) (*entity.Run, error) { + // Convert request payload to domain model + var requestEntity entity.Run + if err := copier.Copy(&requestEntity, &requestPayload); err != nil { + return nil, err + } + + // Get the existing stack by id + updatedEntity, err := m.runRepo.Get(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUpdatingNonExistingStack + } + return nil, err + } + + // Overwrite non-zero values in request entity to existing entity + copier.CopyWithOption(updatedEntity, requestEntity, copier.Option{IgnoreEmpty: true}) + + // Update stack with repository + err = m.runRepo.Update(ctx, updatedEntity) + if err != nil { + return nil, err + } + return updatedEntity, nil +} diff --git a/pkg/server/manager/stack/stack_test.go b/pkg/server/manager/stack/stack_test.go index 43d33c63..fa7fff78 100644 --- a/pkg/server/manager/stack/stack_test.go +++ b/pkg/server/manager/stack/stack_test.go @@ -524,10 +524,11 @@ func TestNewStackManager(t *testing.T) { projectRepo := &mockProjectRepository{} workspaceRepo := &mockWorkspaceRepository{} resourceRepo := persistence.NewResourceRepository(fakeGDB) + runRepo := persistence.NewRunRepository(fakeGDB) defaultBackend := entity.Backend{} maxConcurrent := 10 - manager := NewStackManager(stackRepo, projectRepo, workspaceRepo, resourceRepo, defaultBackend, maxConcurrent) + manager := NewStackManager(stackRepo, projectRepo, workspaceRepo, resourceRepo, runRepo, defaultBackend, maxConcurrent) assert.NotNil(t, manager) assert.Equal(t, stackRepo, manager.stackRepo) diff --git a/pkg/server/manager/stack/types.go b/pkg/server/manager/stack/types.go index 781fa672..fccc99a1 100644 --- a/pkg/server/manager/stack/types.go +++ b/pkg/server/manager/stack/types.go @@ -3,8 +3,10 @@ package stack import ( "errors" + "kusionstack.io/kusion/pkg/domain/constant" "kusionstack.io/kusion/pkg/domain/entity" "kusionstack.io/kusion/pkg/domain/repository" + cache "kusionstack.io/kusion/pkg/server/util/cache" ) const ( @@ -23,6 +25,7 @@ var ( ErrDryrunDestroy = errors.New("dryrun-mode is enabled, no resources will be destroyed") ErrStackInOperation = errors.New("the stack is being operated by another request. Please wait until it is completed") ErrStackNotPreviewedYet = errors.New("the stack has not been previewed yet. Please generate and preview the stack first") + ErrInvalidRunID = errors.New("the run ID should be a uuid") ) type StackManager struct { @@ -30,8 +33,15 @@ type StackManager struct { projectRepo repository.ProjectRepository workspaceRepo repository.WorkspaceRepository resourceRepo repository.ResourceRepository + runRepo repository.RunRepository defaultBackend entity.Backend maxConcurrent int + repoCache *cache.Cache[uint, *StackCache] +} + +type StackCache struct { + LocalDirOnDisk string + StackPath string } type StackRequestParams struct { @@ -48,12 +58,18 @@ type StackExecuteParams struct { SpecID string Force bool ImportResources bool + NoCache bool +} + +type RunRequestParams struct { + RunID uint } func NewStackManager(stackRepo repository.StackRepository, projectRepo repository.ProjectRepository, workspaceRepo repository.WorkspaceRepository, resourceRepo repository.ResourceRepository, + runRepo repository.RunRepository, defaultBackend entity.Backend, maxConcurrent int, ) *StackManager { @@ -62,7 +78,9 @@ func NewStackManager(stackRepo repository.StackRepository, projectRepo: projectRepo, workspaceRepo: workspaceRepo, resourceRepo: resourceRepo, + runRepo: runRepo, defaultBackend: defaultBackend, maxConcurrent: maxConcurrent, + repoCache: cache.NewCache[uint, *StackCache](constant.RepoCacheTTL), } } diff --git a/pkg/server/manager/stack/util.go b/pkg/server/manager/stack/util.go index 2c3dea7f..0eb0ca98 100644 --- a/pkg/server/manager/stack/util.go +++ b/pkg/server/manager/stack/util.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "os" "path/filepath" "regexp" "strconv" @@ -60,6 +61,44 @@ func GetWorkDirFromSource(ctx context.Context, stack *entity.Stack, project *ent return directory, workDir, nil } +// GetWorkdirAndDirectory is a helper function to get the workdir and directory for a stack +func (m *StackManager) GetWorkdirAndDirectory(ctx context.Context, params *StackRequestParams, stackEntity *entity.Stack) (directory string, workDir string, err error) { + logger := logutil.GetLogger(ctx) + logger.Info("Getting workdir and directory...") + if params.ExecuteParams.NoCache { + // If noCache is set, checkout workdir + logger.Info("Stack not found in cache. Pulling repo and set cache...") + directory, workDir, err = GetWorkDirFromSource(ctx, stackEntity, stackEntity.Project) + if err != nil { + return "", "", err + } + sc := &StackCache{ + LocalDirOnDisk: directory, + StackPath: workDir, + } + m.repoCache.Set(stackEntity.ID, sc) + } else { + // If repoCacheEnv is set, use the cached directory. This takes precedence over the in-memory cache + repoCacheEnv := os.Getenv("KUSION_SERVER_REPO_CACHE") + if repoCacheEnv != "" { + logger.Info("Repo cache found in env var. Using cached directory...") + directory = repoCacheEnv + workDir = fmt.Sprintf("%s/%s", directory, stackEntity.Path) + } else { + // No env var found, check if stack is in cache + logger.Info("No repo cache found in env var. Checking cache...") + stackCache, cacheExists := m.repoCache.Get(stackEntity.ID) + if cacheExists { + // if found in repoCache, use the cached workDir and directory + logger.Info("Stack found in cache. Using cache...") + workDir = stackCache.StackPath + directory = stackCache.LocalDirOnDisk + } + } + } + return directory, workDir, nil +} + func ProcessChanges(ctx context.Context, w http.ResponseWriter, changes *models.Changes, format string, detail bool) (any, error) { logger := logutil.GetLogger(ctx) logger.Info("Starting previewing stack in StackManager ...") @@ -144,7 +183,7 @@ func (m *StackManager) metaHelper( } // Get workspace configurations from backend - // TODO: temporarily local for now, should be replaced by variable sets + // TODO: should be replaced by variable sets wsStorage, err := stackBackend.WorkspaceStorage() if err != nil { return nil, nil, nil, nil, nil, err @@ -228,6 +267,33 @@ func (m *StackManager) BuildStackFilter(ctx context.Context, orgIDParam, project return &filter, nil } +func (m *StackManager) BuildRunFilter(ctx context.Context, projectIDParam, stackIDParam, workspaceParam string) (*entity.RunFilter, error) { + logger := logutil.GetLogger(ctx) + logger.Info("Building run filter...") + filter := entity.RunFilter{} + if projectIDParam != "" { + // if project id is present, use project id + projectID, err := strconv.Atoi(projectIDParam) + if err != nil { + return nil, constant.ErrInvalidProjectID + } + filter.ProjectID = uint(projectID) + } + if stackIDParam != "" { + // if project id is present, use project id + stackID, err := strconv.Atoi(stackIDParam) + if err != nil { + return nil, constant.ErrInvalidStackID + } + filter.StackID = uint(stackID) + } + if workspaceParam != "" { + // if workspace is present, use workspace + filter.Workspace = workspaceParam + } + return &filter, nil +} + func (m *StackManager) ImportTerraformResourceID(ctx context.Context, sp *v1.Spec, importedResources map[string]string) { for k, res := range sp.Resources { // only for terraform resources @@ -333,8 +399,8 @@ func isKubernetesResource(resource *v1.Resource) bool { return resource.Type == v1.Kubernetes } -func getReleasePath(stackPath, namespace string) string { - return fmt.Sprintf("%s/%s", namespace, stackPath) +func getReleasePath(namespace, source, projectPath, workspace string) string { + return fmt.Sprintf("%s/%s/%s/%s", namespace, source, projectPath, workspace) } func isInRelease(release *v1.Release, id string) bool { diff --git a/pkg/server/middleware/logger.go b/pkg/server/middleware/logger.go index ef07a8df..2595d93e 100644 --- a/pkg/server/middleware/logger.go +++ b/pkg/server/middleware/logger.go @@ -1,10 +1,12 @@ package middleware import ( + "bytes" "context" "log/slog" "net/http" "os" + "path/filepath" "time" "github.com/go-chi/httplog/v2" @@ -12,12 +14,26 @@ import ( ) // APILoggerKey is a context key used for associating a logger with a request. -var APILoggerKey = &contextKey{"logger"} +var ( + APILoggerKey = &contextKey{"logger"} + RunLoggerKey = &contextKey{"runLogger"} + RunLoggerBufferKey = &contextKey{"runLoggerBuffer"} +) func InitLogger(logFilePath string, name string) *httplog.Logger { logWriter, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) if err != nil { - klog.Fatalf("Failed to open log file: %v", err) + // if directory does not exist, try to create the directory + if os.IsNotExist(err) { + logFileParent := filepath.Dir(logFilePath) + klog.Infof("Log directory does not exist, trying to create the directory in %s", logFileParent) + err = os.MkdirAll(logFileParent, 0o755) + if err != nil { + klog.Fatalf("Failed to create log directory: %v", err) + } + } else { + klog.Fatalf("Failed to open log file: %v", err) + } } logger := httplog.NewLogger(name, httplog.Options{ LogLevel: slog.LevelInfo, @@ -32,6 +48,21 @@ func InitLogger(logFilePath string, name string) *httplog.Logger { return logger } +func InitLoggerBuffer(name string) (*httplog.Logger, *bytes.Buffer) { + var buffer bytes.Buffer + logger := httplog.NewLogger(name, httplog.Options{ + LogLevel: slog.LevelInfo, + Concise: true, + TimeFieldFormat: time.RFC3339, + Writer: &buffer, + RequestHeaders: true, + Trace: &httplog.TraceOptions{ + HeaderTrace: "x-kusion-trace", + }, + }) + return logger, &buffer +} + // APILoggerMiddleware injects a logger, configured with a request ID, // into the request context for use throughout the request's lifecycle. func APILoggerMiddleware(logFile string) func(http.Handler) http.Handler { @@ -42,7 +73,10 @@ func APILoggerMiddleware(logFile string) func(http.Handler) http.Handler { if requestID := GetTraceID(ctx); len(requestID) > 0 { // Set the output file for klog logger := InitLogger(logFile, requestID) + runLogger, logBuffer := InitLoggerBuffer(requestID) ctx = context.WithValue(ctx, APILoggerKey, logger) + ctx = context.WithValue(ctx, RunLoggerKey, runLogger) + ctx = context.WithValue(ctx, RunLoggerBufferKey, logBuffer) } // Continue serving the request with the new context. next.ServeHTTP(w, r.WithContext(ctx)) diff --git a/pkg/server/route/route.go b/pkg/server/route/route.go index 93c59853..a4a64af4 100644 --- a/pkg/server/route/route.go +++ b/pkg/server/route/route.go @@ -132,8 +132,9 @@ func setupRestAPIV1( backendRepo := persistence.NewBackendRepository(config.DB) resourceRepo := persistence.NewResourceRepository(config.DB) moduleRepo := persistence.NewModuleRepository(config.DB) + runRepo := persistence.NewRunRepository(config.DB) - stackManager := stackmanager.NewStackManager(stackRepo, projectRepo, workspaceRepo, resourceRepo, config.DefaultBackend, config.MaxConcurrent) + stackManager := stackmanager.NewStackManager(stackRepo, projectRepo, workspaceRepo, resourceRepo, runRepo, config.DefaultBackend, config.MaxConcurrent) sourceManager := sourcemanager.NewSourceManager(sourceRepo) organizationManager := organizationmanager.NewOrganizationManager(organizationRepo) backendManager := backendmanager.NewBackendManager(backendRepo) @@ -158,7 +159,7 @@ func setupRestAPIV1( logger.Error(err.Error(), "Error creating project handler...", "error", err) return } - stackHandler, err := stack.NewHandler(stackManager) + stackHandler, err := stack.NewHandler(stackManager, config.MaxAsyncConcurrent, config.MaxAsyncBuffer) if err != nil { logger.Error(err.Error(), "Error creating stack handler...", "error", err) return @@ -194,12 +195,24 @@ func setupRestAPIV1( r.Post("/", sourceHandler.CreateSource()) r.Get("/", sourceHandler.ListSources()) }) + r.Route("/runs", func(r chi.Router) { + r.Route("/{runID}", func(r chi.Router) { + r.Get("/", stackHandler.GetRun()) + r.Get("/result", stackHandler.GetRunResult()) + }) + // r.Post("/", backendHandler.CreateRun()) + r.Get("/", stackHandler.ListRuns()) + }) r.Route("/stacks", func(r chi.Router) { r.Route("/{stackID}", func(r chi.Router) { r.Post("/generate", stackHandler.GenerateStack()) + r.Post("/generate/async", stackHandler.GenerateStackAsync()) r.Post("/preview", stackHandler.PreviewStack()) + r.Post("/preview/async", stackHandler.PreviewStackAsync()) r.Post("/apply", stackHandler.ApplyStack()) + r.Post("/apply/async", stackHandler.ApplyStackAsync()) r.Post("/destroy", stackHandler.DestroyStack()) + r.Post("/destroy/async", stackHandler.DestroyStackAsync()) // r.Route("/variable", func(r chi.Router) { // r.Post("/", stackHandler.UpdateStackVariable()) // }) diff --git a/pkg/server/util/cache/cache.go b/pkg/server/util/cache/cache.go new file mode 100644 index 00000000..3e2f7428 --- /dev/null +++ b/pkg/server/util/cache/cache.go @@ -0,0 +1,79 @@ +// Copyright The Karpor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "sync" + "time" +) + +// Cache manages the caching of items based on keys with +// expiration time for cached items. +type Cache[K comparable, V any] struct { + cache map[K]*CacheItem[V] + mu sync.RWMutex + expiration time.Duration +} + +// CacheItem represents an item stored in the cache along with its expiration +// time. +type CacheItem[V any] struct { + Data V + ExpiryTime time.Time +} + +// NewCache creates a new Cache instance with a specified expiration time. +func NewCache[K comparable, V any](expiration time.Duration) *Cache[K, V] { + return &Cache[K, V]{ + cache: make(map[K]*CacheItem[V]), + expiration: expiration, + } +} + +// Get retrieves an item from the cache based on the provided key. It returns +// the data and a boolean indicating if the data exists and hasn't expired. +func (c *Cache[K, V]) Get(key K) (V, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + item, exist := c.cache[key] + if !exist { + return zeroValue[V](), false + } + + if time.Now().After(item.ExpiryTime) { + delete(c.cache, key) + return zeroValue[V](), false + } + + return item.Data, true +} + +// Set adds or updates an item in the cache with the provided key and data. +func (c *Cache[K, V]) Set(key K, data V) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cache[key] = &CacheItem[V]{ + Data: data, + ExpiryTime: time.Now().Add(c.expiration), + } +} + +// zeroValue returns the zero value of type V. +func zeroValue[V any]() V { + var zero V + return zero +} diff --git a/pkg/server/util/cache/cache_test.go b/pkg/server/util/cache/cache_test.go new file mode 100644 index 00000000..e57a1a35 --- /dev/null +++ b/pkg/server/util/cache/cache_test.go @@ -0,0 +1,93 @@ +// Copyright The Karpor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "sync" + "testing" + "time" +) + +const MockCacheValue = "test value" + +func TestCache_SetAndGet(t *testing.T) { + expiration := 100 * time.Millisecond + cache := NewCache[int, string](expiration) + + key := 42 + cache.Set(key, MockCacheValue) + + // Check if the value is retrieved correctly + retrievedValue, exists := cache.Get(key) + if !exists { + t.Errorf("Expected value '%s' to exist in cache, but it doesn't.", MockCacheValue) + } + if retrievedValue != MockCacheValue { + t.Errorf("Expected value '%s', got '%s'", MockCacheValue, retrievedValue) + } + + // Wait for the value to expire + time.Sleep(expiration + 50*time.Millisecond) + + // Check if the value is expired + _, exists = cache.Get(key) + if exists { + t.Error("Expected value to be expired, but it still exists in cache.") + } +} + +func TestCache_SetAndGet_Concurrent(t *testing.T) { + expiration := 100 * time.Millisecond + cache := NewCache[int, string](expiration) + + key := 42 + + // Concurrently set and get the value + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + cache.Set(key, MockCacheValue) + }() + + go func() { + defer wg.Done() + time.Sleep(50 * time.Millisecond) + retrievedValue, exists := cache.Get(key) + if !exists || retrievedValue != MockCacheValue { + t.Errorf("Concurrent Set/Get: Expected value '%s', got '%s'", MockCacheValue, retrievedValue) + } + }() + + wg.Wait() +} + +func TestCache_ExpiredKeyIsDeleted(t *testing.T) { + expiration := 100 * time.Millisecond + cache := NewCache[int, string](expiration) + + key := 42 + cache.Set(key, MockCacheValue) + + // Wait for the value to expire + time.Sleep(expiration + 50*time.Millisecond) + + // Access the expired key + _, exists := cache.Get(key) + if exists { + t.Error("Expected expired key to be automatically deleted from the cache.") + } +} diff --git a/pkg/server/util/logging/ctxutil.go b/pkg/server/util/logging/ctxutil.go index bed4cf4b..881e5b19 100644 --- a/pkg/server/util/logging/ctxutil.go +++ b/pkg/server/util/logging/ctxutil.go @@ -1,6 +1,7 @@ package util import ( + "bytes" "context" "github.com/go-chi/httplog/v2" @@ -19,3 +20,21 @@ func GetLogger(ctx context.Context) *httplog.Logger { return httplog.NewLogger("DefaultLogger") } + +// GetRunLogger returns the run logger from the given context. +func GetRunLogger(ctx context.Context) *httplog.Logger { + if logger, ok := ctx.Value(middleware.RunLoggerKey).(*httplog.Logger); ok { + return logger + } + + return httplog.NewLogger("DefaultRunLogger") +} + +// GetRunLoggerBuffer returns the run logger buffer from the given context. +func GetRunLoggerBuffer(ctx context.Context) *bytes.Buffer { + if buffer, ok := ctx.Value(middleware.RunLoggerBufferKey).(*bytes.Buffer); ok { + return buffer + } + + return &bytes.Buffer{} +}