Skip to content

Commit

Permalink
feat: Enforce rules on cluster and updater IDs
Browse files Browse the repository at this point in the history
* Use endTime of window as last_updated_at in DB table

Signed-off-by: Mahendra Paipuri <[email protected]>
  • Loading branch information
mahendrapaipuri committed Jun 25, 2024
1 parent f2d98be commit f3253b3
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 13 deletions.
6 changes: 6 additions & 0 deletions pkg/api/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package base

import (
"fmt"
"regexp"
"time"

"github.com/alecthomas/kingpin/v2"
Expand Down Expand Up @@ -68,3 +69,8 @@ var (

// APIVersion sets the version of API in paths
const APIVersion = "v1"

// Cluster and Updater ID valid regex
var (
InvalidIDRegex = regexp.MustCompile("[^a-zA-Z0-9-_]")
)
9 changes: 5 additions & 4 deletions pkg/api/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func (s *statsDB) getUnitStats(startTime, endTime time.Time) error {

// Insert data into DB
level.Debug(s.logger).Log("msg", "Executing SQL statements")
s.execStatements(sqlStmts, units, users, projects)
s.execStatements(sqlStmts, endTime, units, users, projects)
level.Debug(s.logger).Log("msg", "Finished executing SQL statements")

// Commit changes
Expand Down Expand Up @@ -583,6 +583,7 @@ func (s *statsDB) prepareStatements(tx *sql.Tx) (map[string]*sql.Stmt, error) {
// Insert unit stat into DB
func (s *statsDB) execStatements(
statements map[string]*sql.Stmt,
currentTime time.Time,
clusterUnits []models.ClusterUnits,
clusterUsers []models.ClusterUsers,
clusterProjects []models.ClusterProjects,
Expand Down Expand Up @@ -637,7 +638,7 @@ func (s *statsDB) execStatements(
sql.Named(base.UnitsDBTableStructFieldColNameMap["Tags"], unit.Tags),
sql.Named(base.UnitsDBTableStructFieldColNameMap["ignore"], ignore),
sql.Named(base.UnitsDBTableStructFieldColNameMap["numupdates"], 1),
sql.Named(base.UsageDBTableStructFieldColNameMap["lastupdatedat"], time.Now().Format(base.DatetimeLayout)),
sql.Named(base.UsageDBTableStructFieldColNameMap["lastupdatedat"], currentTime.Format(base.DatetimeLayout)),
); err != nil {
level.Error(s.logger).
Log("msg", "Failed to insert unit in DB", "cluster_id", cluster.Cluster.ID, "uuid", unit.UUID, "err", err)
Expand All @@ -658,7 +659,7 @@ func (s *statsDB) execStatements(
sql.Named(base.UsageDBTableStructFieldColNameMap["NumUnits"], unitIncr),
sql.Named(base.UsageDBTableStructFieldColNameMap["Project"], unit.Project),
sql.Named(base.UsageDBTableStructFieldColNameMap["Usr"], unit.Usr),
sql.Named(base.UsageDBTableStructFieldColNameMap["lastupdatedat"], time.Now().Format(base.DatetimeLayout)),
sql.Named(base.UsageDBTableStructFieldColNameMap["lastupdatedat"], currentTime.Format(base.DatetimeLayout)),
sql.Named(base.UnitsDBTableStructFieldColNameMap["TotalWallTime"], unit.TotalWallTime),
sql.Named(base.UnitsDBTableStructFieldColNameMap["TotalCPUTime"], unit.TotalCPUTime),
sql.Named(base.UnitsDBTableStructFieldColNameMap["TotalGPUTime"], unit.TotalGPUTime),
Expand Down Expand Up @@ -727,7 +728,7 @@ func (s *statsDB) execStatements(
if _, err = statements[base.AdminUsersDBTableName].Exec(
sql.Named(base.AdminUsersDBTableStructFieldColNameMap["Source"], source),
sql.Named(base.AdminUsersDBTableStructFieldColNameMap["Users"], s.admin.users[source]),
sql.Named(base.AdminUsersDBTableStructFieldColNameMap["LastUpdatedAt"], time.Now().Format(base.DatetimeLayout)),
sql.Named(base.AdminUsersDBTableStructFieldColNameMap["LastUpdatedAt"], currentTime.Format(base.DatetimeLayout)),
); err != nil {
level.Error(s.logger).
Log("msg", "Failed to update admin_users table in DB", "source", source, "err", err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/api/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,8 @@ func populateDBWithMockData(s *statsDB) {
if err != nil {
fmt.Println(err)
}
s.execStatements(stmtMap, mockUnitsOne, mockUsersOne, mockProjectsOne)
s.execStatements(stmtMap, mockUnitsTwo, nil, nil)
s.execStatements(stmtMap, time.Now(), mockUnitsOne, mockUsersOne, mockProjectsOne)
s.execStatements(stmtMap, time.Now(), mockUnitsTwo, nil, nil)
tx.Commit()
}

Expand Down Expand Up @@ -866,7 +866,7 @@ func TestUnitStatsDeleteOldUnits(t *testing.T) {
if err != nil {
t.Errorf("Failed to prepare SQL statements: %s", err)
}
s.execStatements(stmtMap, units, nil, nil)
s.execStatements(stmtMap, time.Now(), units, nil, nil)

// Now clean up DB for old units
err = s.purgeExpiredUnits(tx)
Expand Down
6 changes: 1 addition & 5 deletions pkg/api/db/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@ import (
func TestJobStatsDBPreparation(t *testing.T) {
tmpDir := t.TempDir()
statDBPath := filepath.Join(tmpDir, "stats.db")
s := &storageConfig{
dbPath: statDBPath,
}
j := statsDB{
logger: log.NewNopLogger(),
storage: s,
logger: log.NewNopLogger(),
}

// Test setupDB function
Expand Down
5 changes: 4 additions & 1 deletion pkg/api/resource/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ func checkConfig(managers []string, config *Config[models.Cluster]) (map[string]
var configMap = make(map[string][]models.Cluster)
for i := 0; i < len(config.Clusters); i++ {
if slices.Contains(IDs, config.Clusters[i].ID) {
return nil, fmt.Errorf("duplicate ID found in resource managers config")
return nil, fmt.Errorf("duplicate ID found in clusters config")
}
if !slices.Contains(managers, config.Clusters[i].Manager) {
return nil, fmt.Errorf("unknown resource manager found in the config: %s", config.Clusters[i].Manager)
}
if base.InvalidIDRegex.MatchString(config.Clusters[i].ID) {
return nil, fmt.Errorf("invalid ID %s found in clusters config. It must contain only [a-zA-Z0-9-_]", config.Clusters[i].ID)
}
IDs = append(IDs, config.Clusters[i].ID)
configMap[config.Clusters[i].Manager] = append(configMap[config.Clusters[i].Manager], config.Clusters[i])
}
Expand Down
26 changes: 26 additions & 0 deletions pkg/api/resource/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ clusters:
// Missing s in clusters
configFileTmpl = `
---
# %[1]s %[2]s
cluster:
- id: default`
case "malformed_2":
// Missing manager name
configFileTmpl = `
---
# %[1]s
clusters:
- id: default
web:
Expand All @@ -165,11 +167,22 @@ clusters:
// Duplicated IDs
configFileTmpl = `
---
# %[1]s
clusters:
- id: default
web:
url: %[2]s
- id: default
web:
url: %[2]s`
case "malformed_4":
// invalid ID
configFileTmpl = `
---
# %[1]s
clusters:
- id: defau!$lt
manager: slurm
web:
url: %[2]s`
}
Expand Down Expand Up @@ -219,6 +232,19 @@ func TestUnknownManagerConfig(t *testing.T) {
}
}

func TestInvalidIDManagerConfig(t *testing.T) {
// Make mock config
base.ConfigFilePath = mockConfig(t.TempDir(), "malformed_4", "")

cfg, err := managerConfig()
if err != nil {
t.Errorf("failed to create manager config: %s", err)
}
if _, err = checkConfig([]string{"slurm"}, cfg); err == nil {
t.Errorf("expected error due to invalid ID in config. Got none")
}
}

func TestDuplicatedIDsConfig(t *testing.T) {
// Make mock config
base.ConfigFilePath = mockConfig(t.TempDir(), "malformed_3", "")
Expand Down
3 changes: 3 additions & 0 deletions pkg/api/updater/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ func checkConfig(updaters []string, config *Config[Instance]) (map[string][]Inst
if !slices.Contains(updaters, config.Instances[i].Updater) {
return nil, fmt.Errorf("unknown updater found in the config: %s", config.Instances[i].Updater)
}
if base.InvalidIDRegex.MatchString(config.Instances[i].ID) {
return nil, fmt.Errorf("invalid ID %s found in updaters config. It must contain only [a-zA-Z0-9-_]", config.Instances[i].ID)
}
IDs = append(IDs, config.Instances[i].ID)
configMap[config.Instances[i].Updater] = append(configMap[config.Instances[i].Updater], config.Instances[i])
}
Expand Down
26 changes: 26 additions & 0 deletions pkg/api/updater/updater_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ updaters:
avg_cpu_usage: foo
avg_cpu_mem_usage: foo
- id: default-1
updater: tsdb
web:
url: %[1]s
extra_config:
Expand All @@ -50,29 +51,41 @@ updaters:
// Missing s in tsbd_instances
configFileTmpl = `
---
# %[1]s %[2]s
updater:
- id: default
updater: tsdb`
case "malformed_2":
// Missing updater name
configFileTmpl = `
---
# %[1]s %[2]s
updaters:
- id: default`
case "malformed_3":
// Duplicated IDs
configFileTmpl = `
---
# %[1]s %[2]s
updaters:
- id: default
- id: default`
case "malformed_4":
// Unknown updater
configFileTmpl = `
---
# %[1]s %[2]s
updaters:
- id: default
updater: unknown`
case "malformed_5":
// invalid ID updater
configFileTmpl = `
---
# %[1]s %[2]s
updaters:
- id: defau%lt
updater: tsdb`
}

configFile := fmt.Sprintf(configFileTmpl, serverURL, "2m")
Expand Down Expand Up @@ -120,6 +133,19 @@ func TestUnknownUpdaterConfig(t *testing.T) {
}
}

func TestInvalidIDUpdaterConfig(t *testing.T) {
// Make mock config
base.ConfigFilePath = mockConfig(t.TempDir(), "malformed_5", "http://localhost:9090")

cfg, err := updaterConfig()
if err != nil {
t.Errorf("failed to created updater config: %s", err)
}
if _, err = checkConfig([]string{"tsdb"}, cfg); err == nil {
t.Errorf("expected error due to invalid ID in config. Got none")
}
}

func TestDuplicatedIDsConfig(t *testing.T) {
// Make mock config
base.ConfigFilePath = mockConfig(t.TempDir(), "malformed_3", "http://localhost:9090")
Expand Down

0 comments on commit f3253b3

Please sign in to comment.