Skip to content

Commit

Permalink
Add Prepare to StateBase and update SQLair
Browse files Browse the repository at this point in the history
Add a Prepare function to StateBase which caches sqlair statements. This
allows for statement reuse rather than having to reprepare a statement
every time it is used.

The latest update to SQLair introduced a sqlair.ErrNoRows error when
GetAll finds no results, there are quite a few places in the code where
GetAll is used and this caused some error. Checks are put in to fix
this.
  • Loading branch information
Aflynn50 committed Mar 22, 2024
1 parent e6fd53b commit e7c5620
Show file tree
Hide file tree
Showing 23 changed files with 292 additions and 170 deletions.
12 changes: 5 additions & 7 deletions domain/annotation/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (st *State) GetAnnotations(ctx context.Context, id annotations.ID) (map[str
return nil, errors.Trace(err)
}

getAnnotationsStmt, err := sqlair.Prepare(getAnnotationsQuery, Annotation{}, sqlair.M{})
getAnnotationsStmt, err := st.Prepare(getAnnotationsQuery, Annotation{}, sqlair.M{})
if err != nil {
return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name)
}
Expand All @@ -63,7 +63,6 @@ func (st *State) getAnnotationsForModel(ctx context.Context, id annotations.ID,
err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
return tx.Query(ctx, getAnnotationsStmt).GetAll(&annotationsResults)
})

if err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
// No errors, we return empty map if no annotation is found
Expand Down Expand Up @@ -95,7 +94,7 @@ func (st *State) getAnnotationsForID(ctx context.Context, id annotations.ID, get
if err != nil {
return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name)
}
kindQueryStmt, err := sqlair.Prepare(kindQuery, sqlair.M{})
kindQueryStmt, err := st.Prepare(kindQuery, sqlair.M{})
if err != nil {
return nil, errors.Annotatef(err, "preparing get annotations query for ID: %q", id.Name)
}
Expand Down Expand Up @@ -123,7 +122,6 @@ func (st *State) getAnnotationsForID(ctx context.Context, id annotations.ID, get
"uuid": uuid,
}).GetAll(&annotationsResults)
})

if err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
// No errors, we return empty map if no annotation is found
Expand Down Expand Up @@ -170,11 +168,11 @@ func (st *State) SetAnnotations(ctx context.Context, id annotations.ID,
}

// Prepare sqlair statements
setAnnotationsStmt, err := sqlair.Prepare(setAnnotationsQuery, Annotation{}, sqlair.M{})
setAnnotationsStmt, err := st.Prepare(setAnnotationsQuery, Annotation{}, sqlair.M{})
if err != nil {
return errors.Annotatef(err, "preparing set annotations query for ID: %q", id.Name)
}
deleteAnnotationsStmt, err := sqlair.Prepare(deleteAnnotationsQuery, Annotation{}, sqlair.M{})
deleteAnnotationsStmt, err := st.Prepare(deleteAnnotationsQuery, Annotation{}, sqlair.M{})
if err != nil {
return errors.Annotatef(err, "preparing set annotations query for ID: %q", id.Name)
}
Expand Down Expand Up @@ -204,7 +202,7 @@ func (st *State) setAnnotationsForID(ctx context.Context, id annotations.ID,
if err != nil {
return errors.Annotatef(err, "preparing uuid retrieval query for ID: %q", id.Name)
}
kindQueryStmt, err := sqlair.Prepare(kindQuery, sqlair.M{})
kindQueryStmt, err := st.Prepare(kindQuery, sqlair.M{})
if err != nil {
return errors.Annotatef(err, "preparing uuid retrieval query for ID: %q", id.Name)
}
Expand Down
29 changes: 16 additions & 13 deletions domain/application/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (st *State) UpsertApplication(ctx context.Context, name string, units ...ap

appNameParam := sqlair.M{"name": name}
query := `SELECT &M.uuid FROM application WHERE name = $M.name`
queryStmt, err := sqlair.Prepare(query, sqlair.M{})
queryStmt, err := st.Prepare(query, sqlair.M{})
if err != nil {
return errors.Trace(err)
}
Expand All @@ -60,12 +60,12 @@ func (st *State) UpsertApplication(ctx context.Context, name string, units ...ap
INSERT INTO application (uuid, name, life_id)
VALUES ($M.application_uuid, $M.name, $M.life_id)
`
createApplicationStmt, err := sqlair.Prepare(createApplication, sqlair.M{})
createApplicationStmt, err := st.Prepare(createApplication, sqlair.M{})
if err != nil {
return errors.Trace(err)
}

upsertUnitFunc, err := upsertUnitFuncGetter()
upsertUnitFunc, err := st.upsertUnitFuncGetter()
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -117,19 +117,19 @@ func (st *State) DeleteApplication(ctx context.Context, name string) error {
appNameParam := sqlair.M{"name": name}

queryApplication := `SELECT &M.uuid FROM application WHERE name = $M.name`
queryApplicationStmt, err := sqlair.Prepare(queryApplication, sqlair.M{})
queryApplicationStmt, err := st.Prepare(queryApplication, sqlair.M{})
if err != nil {
return errors.Trace(err)
}

queryUnits := `SELECT count(*) AS &M.count FROM unit WHERE application_uuid = $M.application_uuid`
queryUnitsStmt, err := sqlair.Prepare(queryUnits, sqlair.M{})
queryUnitsStmt, err := st.Prepare(queryUnits, sqlair.M{})
if err != nil {
return errors.Trace(err)
}

deleteApplication := `DELETE FROM application WHERE name = $M.name`
deleteApplicationStmt, err := sqlair.Prepare(deleteApplication, sqlair.M{})
deleteApplicationStmt, err := st.Prepare(deleteApplication, sqlair.M{})
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -175,7 +175,7 @@ func (st *State) AddUnits(ctx context.Context, applicationName string, args ...a
return errors.Trace(err)
}

upsertUnitFunc, err := upsertUnitFuncGetter()
upsertUnitFunc, err := st.upsertUnitFuncGetter()
if err != nil {
return errors.Trace(err)
}
Expand All @@ -196,15 +196,15 @@ type upsertUnitFunc func(ctx context.Context, tx *sqlair.TX, appName string, par
// upsertUnitFuncGetter returns a function which can be called as many times
// as needed to add units, ensuring that statement preparation is only done once.
// TODO - this just creates a minimal row for now.
func upsertUnitFuncGetter() (upsertUnitFunc, error) {
func (st *State) upsertUnitFuncGetter() (upsertUnitFunc, error) {
query := `SELECT &M.uuid FROM unit WHERE unit_id = $M.name`
queryStmt, err := sqlair.Prepare(query, sqlair.M{})
queryStmt, err := st.Prepare(query, sqlair.M{})
if err != nil {
return nil, errors.Trace(err)
}

queryApplication := `SELECT &M.uuid FROM application WHERE name = $M.name`
queryApplicationStmt, err := sqlair.Prepare(queryApplication, sqlair.M{})
queryApplicationStmt, err := st.Prepare(queryApplication, sqlair.M{})
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -213,13 +213,13 @@ func upsertUnitFuncGetter() (upsertUnitFunc, error) {
INSERT INTO unit (uuid, net_node_uuid, unit_id, life_id, application_uuid)
VALUES ($M.unit_uuid, $M.net_node_uuid, $M.unit_id, $M.life_id, $M.application_uuid)
`
createUnitStmt, err := sqlair.Prepare(createUnit, sqlair.M{})
createUnitStmt, err := st.Prepare(createUnit, sqlair.M{})
if err != nil {
return nil, errors.Trace(err)
}

createNode := `INSERT INTO net_node (uuid) VALUES ($M.net_node_uuid)`
createNodeStmt, err := sqlair.Prepare(createNode, sqlair.M{})
createNodeStmt, err := st.Prepare(createNode, sqlair.M{})
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -292,7 +292,7 @@ func (st *State) StorageDefaults(ctx context.Context) (domainstorage.StorageDefa

attrs := []string{application.StorageDefaultBlockSourceKey, application.StorageDefaultFilesystemSourceKey}
attrsSlice := sqlair.S(transform.Slice(attrs, func(s string) any { return any(s) }))
stmt, err := sqlair.Prepare(`
stmt, err := st.Prepare(`
SELECT &KeyValue.* FROM model_config WHERE key IN ($S[:])
`, sqlair.S{}, KeyValue{})
if err != nil {
Expand All @@ -303,6 +303,9 @@ SELECT &KeyValue.* FROM model_config WHERE key IN ($S[:])
var values []KeyValue
err := tx.Query(ctx, stmt, attrsSlice).GetAll(&values)
if err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
return nil
}
return fmt.Errorf("getting model config attrs for storage defaults: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion domain/autocert/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (st *State) Get(ctx context.Context, name string) ([]byte, error) {
SELECT (name, data) AS (&Autocert.*)
FROM autocert_cache
WHERE name = $M.name`
s, err := sqlair.Prepare(q, Autocert{}, sqlair.M{})
s, err := st.Prepare(q, Autocert{}, sqlair.M{})
if err != nil {
return nil, errors.Annotatef(err, "preparing %q", q)
}
Expand Down
35 changes: 19 additions & 16 deletions domain/blockdevice/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ func (st *State) BlockDevices(ctx context.Context, machineId string) ([]blockdev
var result []blockdevice.BlockDevice
err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
var err error
result, err = loadBlockDevices(ctx, tx, machineId)
result, err = st.loadBlockDevices(ctx, tx, machineId)
return errors.Trace(err)
})
return result, errors.Trace(err)
}

func loadBlockDevices(ctx context.Context, tx *sqlair.TX, machineId string) ([]blockdevice.BlockDevice, error) {
func (st *State) loadBlockDevices(ctx context.Context, tx *sqlair.TX, machineId string) ([]blockdevice.BlockDevice, error) {
query := `
SELECT bd.* AS &BlockDevice.*,
bdl.* AS &DeviceLink.*,
Expand All @@ -71,7 +71,7 @@ WHERE machine.machine_id = $M.machine_id
sqlair.M{},
}

stmt, err := sqlair.Prepare(query, types...)
stmt, err := st.Prepare(query, types...)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -84,19 +84,22 @@ WHERE machine.machine_id = $M.machine_id
machineParam := sqlair.M{"machine_id": machineId}
err = tx.Query(ctx, stmt, machineParam).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes)
if err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
return nil, nil
}
return nil, errors.Annotatef(err, "loading block devices for machine %q", machineId)
}
result, _, err := dbRows.toBlockDevicesAndMachines(dbDeviceLinks, dbFilesystemTypes, nil)
return result, errors.Trace(err)
}

func getMachineInfo(ctx context.Context, tx *sqlair.TX, machineId string) (string, life.Life, error) {
func (st *State) getMachineInfo(ctx context.Context, tx *sqlair.TX, machineId string) (string, life.Life, error) {
q := `
SELECT machine.life_id AS &M.life_id, machine.uuid AS &M.machine_uuid
FROM machine
WHERE machine.machine_id = $M.machine_id
`
stmt, err := sqlair.Prepare(q, sqlair.M{})
stmt, err := st.Prepare(q, sqlair.M{})
if err != nil {
return "", 0, errors.Trace(err)
}
Expand Down Expand Up @@ -127,22 +130,22 @@ func (st *State) SetMachineBlockDevices(ctx context.Context, machineId string, d
}

err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
machineUUID, machineLife, err := getMachineInfo(ctx, tx, machineId)
machineUUID, machineLife, err := st.getMachineInfo(ctx, tx, machineId)
if err != nil {
return errors.Trace(err)
}
if machineLife == life.Dead {
return errors.Errorf("cannot update block devices on dead machine %q", machineId)
}
existing, err := loadBlockDevices(ctx, tx, machineId)
existing, err := st.loadBlockDevices(ctx, tx, machineId)
if err != nil {
return errors.Annotatef(err, "loading block devices for machine %q", machineId)
}
if !blockDevicesChanged(existing, devices) {
return nil
}

if err := updateBlockDevices(ctx, tx, machineUUID, devices...); err != nil {
if err := st.updateBlockDevices(ctx, tx, machineUUID, devices...); err != nil {
return errors.Annotatef(err, "updating block devices on machine %q (%s)", machineId, machineUUID)
}
return nil
Expand All @@ -151,7 +154,7 @@ func (st *State) SetMachineBlockDevices(ctx context.Context, machineId string, d
return errors.Trace(err)
}

func updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string, devices ...blockdevice.BlockDevice) error {
func (st *State) updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string, devices ...blockdevice.BlockDevice) error {
if err := RemoveMachineBlockDevices(ctx, tx, machineUUID); err != nil {
return errors.Annotatef(err, "removing existing block devices for machine %q", machineUUID)
}
Expand All @@ -161,12 +164,12 @@ func updateBlockDevices(ctx context.Context, tx *sqlair.TX, machineUUID string,
}

fsTypeQuery := `SELECT * AS &FilesystemType.* FROM filesystem_type`
fsTypeStmt, err := sqlair.Prepare(fsTypeQuery, FilesystemType{})
fsTypeStmt, err := st.Prepare(fsTypeQuery, FilesystemType{})
if err != nil {
return errors.Trace(err)
}
var fsTypes []FilesystemType
if err := tx.Query(ctx, fsTypeStmt).GetAll(&fsTypes); err != nil {
if err := tx.Query(ctx, fsTypeStmt).GetAll(&fsTypes); err != nil && !errors.Is(err, sqlair.ErrNoRows) {
return errors.Trace(err)
}
fsTypeByName := make(map[string]int)
Expand All @@ -192,7 +195,7 @@ VALUES (
$BlockDevice.in_use
)
`
insertStmt, err := sqlair.Prepare(insertQuery, BlockDevice{})
insertStmt, err := st.Prepare(insertQuery, BlockDevice{})
if err != nil {
return errors.Trace(err)
}
Expand All @@ -204,7 +207,7 @@ VALUES (
$DeviceLink.name
)
`
insertLinkStmt, err := sqlair.Prepare(insertLinkQuery, DeviceLink{})
insertLinkStmt, err := st.Prepare(insertLinkQuery, DeviceLink{})
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -297,7 +300,7 @@ FROM block_device bd
BlockDeviceMachine{},
}

stmt, err := sqlair.Prepare(query, types...)
stmt, err := st.Prepare(query, types...)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -313,7 +316,7 @@ FROM block_device bd
dbFilesystemTypes []FilesystemType
dbMachines []BlockDeviceMachine
)
if err := tx.Query(ctx, stmt).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes, &dbMachines); err != nil {
if err := tx.Query(ctx, stmt).GetAll(&dbRows, &dbDeviceLinks, &dbFilesystemTypes, &dbMachines); err != nil && !errors.Is(err, sqlair.ErrNoRows) {
return errors.Annotate(err, "loading block devices")
}
blockDevices, machines, err = dbRows.toBlockDevicesAndMachines(dbDeviceLinks, dbFilesystemTypes, dbMachines)
Expand Down Expand Up @@ -399,7 +402,7 @@ func (st *State) WatchBlockDevices(
machineLife life.Life
)
err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error {
machineUUID, machineLife, err = getMachineInfo(ctx, tx, machineId)
machineUUID, machineLife, err = st.getMachineInfo(ctx, tx, machineId)
return errors.Trace(err)
})

Expand Down
17 changes: 10 additions & 7 deletions domain/cloud/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ func (st *State) UpdateCloudDefaults(
return errors.Trace(err)
}

selectStmt, err := sqlair.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{})
selectStmt, err := st.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{})
if err != nil {
return errors.Trace(err)
}

deleteStmt, err := sqlair.Prepare(`
deleteStmt, err := st.Prepare(`
DELETE FROM cloud_defaults
WHERE key IN ($Attrs[:])
AND cloud_uuid = $Cloud.uuid;
Expand Down Expand Up @@ -290,7 +290,7 @@ func (st *State) CloudAllRegionDefaults(
return defaults, fmt.Errorf("getting database instance for cloud region defaults: %w", err)
}

stmt, err := sqlair.Prepare(`
stmt, err := st.Prepare(`
SELECT (cloud_region.name,
cloud_region_defaults.key,
cloud_region_defaults.value)
Expand All @@ -311,6 +311,9 @@ WHERE cloud.name = $Cloud.name
var regionDefaultValues []CloudRegionDefaultValue

if err := tx.Query(ctx, stmt, Cloud{Name: cloudName}).GetAll(&regionDefaultValues); err != nil {
if errors.Is(err, sqlair.ErrNoRows) {
return nil
}
return fmt.Errorf("fetching cloud %q region defaults: %w", cloudName, domain.CoerceError(err))
}

Expand Down Expand Up @@ -343,7 +346,7 @@ func (st *State) UpdateCloudRegionDefaults(
return errors.Trace(err)
}

selectStmt, err := sqlair.Prepare(`
selectStmt, err := st.Prepare(`
SELECT cloud_region.uuid AS &CloudRegion.uuid
FROM cloud_region
INNER JOIN cloud
Expand All @@ -355,7 +358,7 @@ AND cloud_region.name = $CloudRegion.name;
return errors.Trace(err)
}

deleteStmt, err := sqlair.Prepare(`
deleteStmt, err := st.Prepare(`
DELETE FROM cloud_region_defaults
WHERE key IN ($Attrs[:])
AND region_uuid = $CloudRegion.uuid;
Expand All @@ -364,7 +367,7 @@ AND region_uuid = $CloudRegion.uuid;
return errors.Trace(err)
}

upsertStmt, err := sqlair.Prepare(`
upsertStmt, err := st.Prepare(`
INSERT INTO cloud_region_defaults (region_uuid, key, value)
VALUES ($CloudRegionDefaults.region_uuid, $CloudRegionDefaults.key, $CloudRegionDefaults.value)
ON CONFLICT(region_uuid, key) DO UPDATE
Expand Down Expand Up @@ -606,7 +609,7 @@ func (st *State) UpsertCloud(ctx context.Context, cloud cloud.Cloud) error {
return errors.Trace(err)
}

selectUUIDStmt, err := sqlair.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{})
selectUUIDStmt, err := st.Prepare("SELECT &Cloud.uuid FROM cloud WHERE name = $Cloud.name", Cloud{})
if err != nil {
return errors.Trace(domain.CoerceError(err))
}
Expand Down
Loading

0 comments on commit e7c5620

Please sign in to comment.