Skip to content

Commit

Permalink
wip: sanction check configs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Antoine Popineau committed Jan 16, 2025
1 parent 9f2e035 commit 8bc794e
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 46 deletions.
46 changes: 33 additions & 13 deletions dto/scenario_iterations.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@ type ScenarioIterationDto struct {
}

type ScenarioIterationBodyDto struct {
TriggerConditionAstExpression *NodeDto `json:"trigger_condition_ast_expression"`
Rules []RuleDto `json:"rules"`
ScoreReviewThreshold *int `json:"score_review_threshold"`
ScoreBlockAndReviewThreshold *int `json:"score_block_and_review_threshold"`
ScoreRejectThreshold_deprec *int `json:"score_reject_threshold"` //nolint:tagliatelle
ScoreDeclineThreshold *int `json:"score_decline_threshold"`
Schedule string `json:"schedule"`
TriggerConditionAstExpression *NodeDto `json:"trigger_condition_ast_expression"`
Rules []RuleDto `json:"rules"`
SanctionCheckConfig *SanctionCheckConfig `json:"sanction_check_config,omitempty"`
ScoreReviewThreshold *int `json:"score_review_threshold"`
ScoreBlockAndReviewThreshold *int `json:"score_block_and_review_threshold"`
ScoreRejectThreshold_deprec *int `json:"score_reject_threshold"` //nolint:tagliatelle
ScoreDeclineThreshold *int `json:"score_decline_threshold"`
Schedule string `json:"schedule"`
}

type SanctionCheckConfig struct {
Enabled bool `json:"enabled"`
}

func AdaptScenarioIterationWithBodyDto(si models.ScenarioIteration) (ScenarioIterationWithBodyDto, error) {
Expand All @@ -41,6 +46,7 @@ func AdaptScenarioIterationWithBodyDto(si models.ScenarioIteration) (ScenarioIte
ScoreDeclineThreshold: si.ScoreDeclineThreshold,
Schedule: si.Schedule,
Rules: make([]RuleDto, len(si.Rules)),
SanctionCheckConfig: nil,
}
for i, rule := range si.Rules {
apiRule, err := AdaptRuleDto(rule)
Expand All @@ -50,6 +56,11 @@ func AdaptScenarioIterationWithBodyDto(si models.ScenarioIteration) (ScenarioIte
}
body.Rules[i] = apiRule
}
if si.SanctionCheckConfig != nil {
body.SanctionCheckConfig = &SanctionCheckConfig{
Enabled: si.SanctionCheckConfig.Enabled,
}
}

if si.TriggerConditionAstExpression != nil {
triggerDto, err := AdaptNodeDto(*si.TriggerConditionAstExpression)
Expand All @@ -75,26 +86,34 @@ func AdaptScenarioIterationWithBodyDto(si models.ScenarioIteration) (ScenarioIte
// Update iteration DTO
type UpdateScenarioIterationBody struct {
Body struct {
TriggerConditionAstExpression *NodeDto `json:"trigger_condition_ast_expression"`
ScoreReviewThreshold *int `json:"score_review_threshold,omitempty"`
ScoreBlockAndReviewThreshold *int `json:"score_block_and_review_threshold,omitempty"`
ScoreRejectThreshold_deprec *int `json:"score_reject_threshold,omitempty"` //nolint:tagliatelle
ScoreDeclineThreshold *int `json:"score_decline_threshold,omitempty"`
Schedule *string `json:"schedule"`
TriggerConditionAstExpression *NodeDto `json:"trigger_condition_ast_expression"`
SanctionCheckConfig *SanctionCheckConfig `json:"sanction_check_config"`
ScoreReviewThreshold *int `json:"score_review_threshold,omitempty"`
ScoreBlockAndReviewThreshold *int `json:"score_block_and_review_threshold,omitempty"`
ScoreRejectThreshold_deprec *int `json:"score_reject_threshold,omitempty"` //nolint:tagliatelle
ScoreDeclineThreshold *int `json:"score_decline_threshold,omitempty"`
Schedule *string `json:"schedule"`
} `json:"body,omitempty"`
}

func AdaptUpdateScenarioIterationInput(input UpdateScenarioIterationBody, iterationId string) (models.UpdateScenarioIterationInput, error) {
updateScenarioIterationInput := models.UpdateScenarioIterationInput{
Id: iterationId,
Body: models.UpdateScenarioIterationBody{
SanctionCheckConfig: nil,
ScoreReviewThreshold: input.Body.ScoreReviewThreshold,
ScoreBlockAndReviewThreshold: input.Body.ScoreBlockAndReviewThreshold,
ScoreDeclineThreshold: input.Body.ScoreDeclineThreshold,
Schedule: input.Body.Schedule,
},
}

if input.Body.SanctionCheckConfig != nil {
updateScenarioIterationInput.Body.SanctionCheckConfig = &models.SanctionCheckConfig{
Enabled: input.Body.SanctionCheckConfig.Enabled,
}
}

if input.Body.ScoreDeclineThreshold == nil {
updateScenarioIterationInput.Body.ScoreDeclineThreshold = input.Body.ScoreRejectThreshold_deprec
}
Expand All @@ -119,6 +138,7 @@ type CreateScenarioIterationBody struct {
Body *struct {
TriggerConditionAstExpression *NodeDto `json:"trigger_condition_ast_expression"`
Rules []CreateRuleInputBody `json:"rules"`
SanctionCheckConfig *SanctionCheckConfig `json:"sanction_check_config,omitempty"`
ScoreReviewThreshold *int `json:"score_review_threshold,omitempty"`
ScoreBlockAndReviewThreshold *int `json:"score_block_and_review_threshold,omitempty"`
ScoreRejectThreshold_deprec *int `json:"score_reject_threshold,omitempty"` //nolint:tagliatelle
Expand Down
6 changes: 6 additions & 0 deletions models/scenario_iterations.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type ScenarioIteration struct {
UpdatedAt time.Time
TriggerConditionAstExpression *ast.Node
Rules []Rule
SanctionCheckConfig *SanctionCheckConfig
ScoreReviewThreshold *int
ScoreBlockAndReviewThreshold *int
ScoreDeclineThreshold *int
Expand Down Expand Up @@ -46,8 +47,13 @@ type UpdateScenarioIterationInput struct {

type UpdateScenarioIterationBody struct {
TriggerConditionAstExpression *ast.Node
SanctionCheckConfig *SanctionCheckConfig
ScoreReviewThreshold *int
ScoreBlockAndReviewThreshold *int
ScoreDeclineThreshold *int
Schedule *string
}

type SanctionCheckConfig struct {
Enabled bool
}
27 changes: 27 additions & 0 deletions repositories/dbmodels/db_sanction_check_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package dbmodels

import (
"time"

"github.com/checkmarble/marble-backend/models"
"github.com/checkmarble/marble-backend/utils"
)

const TABLE_SANCTION_CHECK_CONFIGS = "sanction_check_configs"

type DBSanctionCheckConfigs struct {
Id string `db:"id"`
ScenarioIterationId string `db:"scenario_iteration_id"`
Enabled bool `db:"enabled"`
UpdatedAt time.Time `db:"updated_at"`
}

var SanctionCheckConfigColumnList = utils.ColumnList[DBSanctionCheckConfigs]()

func AdaptSanctionCheckConfig(db DBSanctionCheckConfigs) (models.SanctionCheckConfig, error) {
scc := models.SanctionCheckConfig{
Enabled: db.Enabled,
}

return scc, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- +goose Up
-- +goose StatementBegin

create table sanction_check_configs (
id uuid primary key default uuid_generate_v4(),
scenario_iteration_id uuid unique,
enabled boolean,
updated_at timestamp with time zone not null default CURRENT_TIMESTAMP,

constraint fk_scneario_iteration
foreign key (scenario_iteration_id)
references scenario_iterations (id)
);

-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin

drop table sanction_check_configs;

-- +goose StatementEnd
35 changes: 35 additions & 0 deletions repositories/sanction_check_config_repository.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package repositories

import (
"context"
"fmt"
"strings"

"github.com/Masterminds/squirrel"
"github.com/checkmarble/marble-backend/models"
"github.com/checkmarble/marble-backend/repositories/dbmodels"
)

func (repo *MarbleDbRepository) GetSanctionCheckConfig(ctx context.Context, exec Executor,
scenarioIterationId string,
) (models.SanctionCheckConfig, error) {
sql := NewQueryBuilder().
Select("*").From(dbmodels.TABLE_SANCTION_CHECK_CONFIGS).
Where(squirrel.Eq{"scenario_iteration_id": scenarioIterationId})

return SqlToModel(ctx, exec, sql, dbmodels.AdaptSanctionCheckConfig)
}

func (repo *MarbleDbRepository) UpdateSanctionCheckConfig(ctx context.Context, exec Executor,
scenarioIterationId string, sanctionCheckConfig models.SanctionCheckConfig,
) (models.SanctionCheckConfig, error) {
sql := NewQueryBuilder().
Insert(dbmodels.TABLE_SANCTION_CHECK_CONFIGS).
Columns("scenario_iteration_id", "enabled").
Values(scenarioIterationId, sanctionCheckConfig.Enabled).
Suffix("ON CONFLICT (scenario_iteration_id) DO UPDATE").
Suffix("SET enabled = EXCLUDED.enabled, updated_at = NOW()").
Suffix(fmt.Sprintf("RETURNING %s", strings.Join(dbmodels.SanctionCheckConfigColumnList, ",")))

return SqlToModel(ctx, exec, sql, dbmodels.AdaptSanctionCheckConfig)
}
16 changes: 16 additions & 0 deletions usecases/sanction_check_config_usecase.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package usecases

import (
"context"

"github.com/checkmarble/marble-backend/models"
"github.com/checkmarble/marble-backend/repositories"
)

type SanctionCheckConfigRepository interface {
GetSanctionCheckConfig(ctx context.Context, exec repositories.Executor, scenarioIterationId string) (models.SanctionCheckConfig, error)
UpdateSanctionCheckConfig(ctx context.Context, exec repositories.Executor,
scenarioIterationId string, sanctionCheckConfig models.SanctionCheckConfig) (models.SanctionCheckConfig, error)
}

// TODO: Will we have a usecase for sanction checks?
87 changes: 60 additions & 27 deletions usecases/scenario_iterations_usecase.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ type IterationUsecaseRepository interface {
}

type ScenarioIterationUsecase struct {
repository IterationUsecaseRepository
enforceSecurity security.EnforceSecurityScenario
scenarioFetcher scenarios.ScenarioFetcher
validateScenarioIteration scenarios.ValidateScenarioIteration
executorFactory executor_factory.ExecutorFactory
transactionFactory executor_factory.TransactionFactory
repository IterationUsecaseRepository
sanctionCheckConfigRepository SanctionCheckConfigRepository
enforceSecurity security.EnforceSecurityScenario
scenarioFetcher scenarios.ScenarioFetcher
validateScenarioIteration scenarios.ValidateScenarioIteration
executorFactory executor_factory.ExecutorFactory
transactionFactory executor_factory.TransactionFactory
}

func (usecase *ScenarioIterationUsecase) ListScenarioIterations(
Expand Down Expand Up @@ -91,6 +92,18 @@ func (usecase *ScenarioIterationUsecase) GetScenarioIteration(ctx context.Contex
if err != nil {
return models.ScenarioIteration{}, err
}

scc, err := usecase.sanctionCheckConfigRepository.GetSanctionCheckConfig(ctx,
usecase.executorFactory.NewExecutor(), si.Id)

switch {
case err == nil:
si.SanctionCheckConfig = &scc
case !errors.Is(err, models.NotFoundError):
return models.ScenarioIteration{}, errors.Wrap(err,
"could not retrieve sanction check config from scenario iteration")
}

if err := usecase.enforceSecurity.ReadScenarioIteration(si); err != nil {
return models.ScenarioIteration{}, err
}
Expand Down Expand Up @@ -148,29 +161,49 @@ func (usecase *ScenarioIterationUsecase) CreateScenarioIteration(ctx context.Con
func (usecase *ScenarioIterationUsecase) UpdateScenarioIteration(ctx context.Context,
organizationId string, scenarioIteration models.UpdateScenarioIterationInput,
) (iteration models.ScenarioIteration, err error) {
exec := usecase.executorFactory.NewExecutor()
scenarioAndIteration, err := usecase.scenarioFetcher.FetchScenarioAndIteration(ctx, exec, scenarioIteration.Id)
updatedScenarioIteration, err := executor_factory.TransactionReturnValue(
ctx,
usecase.transactionFactory,
func(tx repositories.Transaction) (models.ScenarioIteration, error) {
scenarioAndIteration, err := usecase.scenarioFetcher.FetchScenarioAndIteration(ctx, tx, scenarioIteration.Id)
if err != nil {
return iteration, err
}
if err := usecase.enforceSecurity.UpdateScenario(scenarioAndIteration.Scenario); err != nil {
return iteration, err
}

body := scenarioIteration.Body
if body.Schedule != nil && *body.Schedule != "" {
gron := gronx.New()
ok := gron.IsValid(*body.Schedule)
if !ok {
return iteration, fmt.Errorf("invalid schedule: %w", models.BadParameterError)
}
}
if scenarioAndIteration.Iteration.Version != nil {
return iteration, errors.Wrap(
models.ErrScenarioIterationNotDraft,
fmt.Sprintf("iteration %s is not a draft", scenarioAndIteration.Iteration.Id),
)
}

if scenarioIteration.Body.SanctionCheckConfig != nil {
if _, err := usecase.sanctionCheckConfigRepository.UpdateSanctionCheckConfig(ctx, tx,
scenarioAndIteration.Iteration.Id, *scenarioIteration.Body.SanctionCheckConfig); err != nil {
return iteration, err
}
}

scenarioIteration, err := usecase.repository.UpdateScenarioIteration(ctx, tx, scenarioIteration)

return scenarioIteration, err
})
if err != nil {
return iteration, err
}
if err := usecase.enforceSecurity.UpdateScenario(scenarioAndIteration.Scenario); err != nil {
return iteration, err
}
body := scenarioIteration.Body
if body.Schedule != nil && *body.Schedule != "" {
gron := gronx.New()
ok := gron.IsValid(*body.Schedule)
if !ok {
return iteration, fmt.Errorf("invalid schedule: %w", models.BadParameterError)
}
}
if scenarioAndIteration.Iteration.Version != nil {
return iteration, errors.Wrap(
models.ErrScenarioIterationNotDraft,
fmt.Sprintf("iteration %s is not a draft", scenarioAndIteration.Iteration.Id),
)
return models.ScenarioIteration{}, err
}
return usecase.repository.UpdateScenarioIteration(ctx, exec, scenarioIteration)

return updatedScenarioIteration, nil
}

func (usecase *ScenarioIterationUsecase) CreateDraftFromScenarioIteration(
Expand Down
13 changes: 7 additions & 6 deletions usecases/usecases_with_creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,13 @@ func (usecases *UsecasesWithCreds) NewScenarioUsecase() ScenarioUsecase {

func (usecases *UsecasesWithCreds) NewScenarioIterationUsecase() ScenarioIterationUsecase {
return ScenarioIterationUsecase{
repository: &usecases.Repositories.MarbleDbRepository,
enforceSecurity: usecases.NewEnforceScenarioSecurity(),
scenarioFetcher: usecases.NewScenarioFetcher(),
validateScenarioIteration: usecases.NewValidateScenarioIteration(),
executorFactory: usecases.NewExecutorFactory(),
transactionFactory: usecases.NewTransactionFactory(),
repository: &usecases.Repositories.MarbleDbRepository,
sanctionCheckConfigRepository: &usecases.Repositories.MarbleDbRepository,
enforceSecurity: usecases.NewEnforceScenarioSecurity(),
scenarioFetcher: usecases.NewScenarioFetcher(),
validateScenarioIteration: usecases.NewValidateScenarioIteration(),
executorFactory: usecases.NewExecutorFactory(),
transactionFactory: usecases.NewTransactionFactory(),
}
}

Expand Down

0 comments on commit 8bc794e

Please sign in to comment.