Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #321: Always use a single transaction when changing grant #9

Merged
merged 2 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 72 additions & 53 deletions postgresql/resource_postgresql_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ var objectTypes = map[string]string{
func resourcePostgreSQLGrant() *schema.Resource {
return &schema.Resource{
Create: PGResourceFunc(resourcePostgreSQLGrantCreate),
// Since all of this resource's arguments force a recreation
// there's no need for an Update function
// Update:
Update: PGResourceFunc(resourcePostgreSQLGrantUpdate),
Read: PGResourceFunc(resourcePostgreSQLGrantRead),
Delete: PGResourceFunc(resourcePostgreSQLGrantDelete),

Expand All @@ -57,46 +55,46 @@ func resourcePostgreSQLGrant() *schema.Resource {
Description: "The database to grant privileges on for this role",
},
"schema": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Type: schema.TypeString,
Optional: true,
// ForceNew: true,
Description: "The database schema to grant privileges on for this role",
},
"object_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
Type: schema.TypeString,
Required: true,
// ForceNew: true,
ValidateFunc: validation.StringInSlice(allowedObjectTypes, false),
Description: "The PostgreSQL object type to grant the privileges on (one of: " + strings.Join(allowedObjectTypes, ", ") + ")",
},
"objects": {
Type: schema.TypeSet,
Optional: true,
ForceNew: true,
Type: schema.TypeSet,
Optional: true,
// ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
Description: "The specific objects to grant privileges on for this role (empty means all objects of the requested type)",
},
"columns": {
Type: schema.TypeSet,
Optional: true,
ForceNew: true,
Type: schema.TypeSet,
Optional: true,
// ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
Description: "The specific columns to grant privileges on for this role",
},
"privileges": {
Type: schema.TypeSet,
Required: true,
ForceNew: true,
Type: schema.TypeSet,
Required: true,
// ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
Description: "The list of privileges to grant",
},
"with_grant_option": {
Type: schema.TypeBool,
Optional: true,
ForceNew: true,
Type: schema.TypeBool,
Optional: true,
// ForceNew: true,
Default: false,
Description: "Permit the grant recipient to grant it to others",
},
Expand Down Expand Up @@ -129,6 +127,10 @@ func resourcePostgreSQLGrantRead(db *DBConnection, d *schema.ResourceData) error
}

func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) error {
return resourcePostgreSQLGrantCreateOrUpdate(db, d, false)
}

func resourcePostgreSQLGrantCreateOrUpdate(db *DBConnection, d *schema.ResourceData, usePreviousForRevoke bool) error {
if err := validateFeatureSupport(db, d); err != nil {
return fmt.Errorf("feature is not supported: %v", err)
}
Expand Down Expand Up @@ -187,7 +189,7 @@ func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) err
// Revoke all privileges before granting otherwise reducing privileges will not work.
// We just have to revoke them in the same transaction so the role will not lost its
// privileges between the revoke and grant statements.
if err := revokeRolePrivileges(txn, d); err != nil {
if err := revokeRolePrivileges(txn, d, usePreviousForRevoke); err != nil {
return err
}
if err := grantRolePrivileges(txn, d); err != nil {
Expand All @@ -213,6 +215,10 @@ func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) err
return readRolePrivileges(txn, d)
}

func resourcePostgreSQLGrantUpdate(db *DBConnection, d *schema.ResourceData) error {
return resourcePostgreSQLGrantCreateOrUpdate(db, d, true)
}

func resourcePostgreSQLGrantDelete(db *DBConnection, d *schema.ResourceData) error {
if err := validateFeatureSupport(db, d); err != nil {
return fmt.Errorf("feature is not supported: %v", err)
Expand Down Expand Up @@ -243,7 +249,7 @@ func resourcePostgreSQLGrantDelete(db *DBConnection, d *schema.ResourceData) err
}

if err := withRolesGranted(txn, owners, func() error {
return revokeRolePrivileges(txn, d)
return revokeRolePrivileges(txn, d, false)
}); err != nil {
return err
}
Expand Down Expand Up @@ -589,40 +595,42 @@ func createGrantQuery(d *schema.ResourceData, privileges []string) string {
return query
}

func createRevokeQuery(d *schema.ResourceData) string {
type ResourceSchemGetter func(string) interface{}

func createRevokeQuery(getter ResourceSchemGetter) string {
var query string

switch strings.ToUpper(d.Get("object_type").(string)) {
switch strings.ToUpper(getter("object_type").(string)) {
case "DATABASE":
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON DATABASE %s FROM %s",
pq.QuoteIdentifier(d.Get("database").(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
pq.QuoteIdentifier(getter("database").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
case "SCHEMA":
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON SCHEMA %s FROM %s",
pq.QuoteIdentifier(d.Get("schema").(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
pq.QuoteIdentifier(getter("schema").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
case "FOREIGN_DATA_WRAPPER":
fdwName := d.Get("objects").(*schema.Set).List()[0]
fdwName := getter("objects").(*schema.Set).List()[0]
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON FOREIGN DATA WRAPPER %s FROM %s",
pq.QuoteIdentifier(fdwName.(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
case "FOREIGN_SERVER":
srvName := d.Get("objects").(*schema.Set).List()[0]
srvName := getter("objects").(*schema.Set).List()[0]
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON FOREIGN SERVER %s FROM %s",
pq.QuoteIdentifier(srvName.(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
case "COLUMN":
objects := d.Get("objects").(*schema.Set)
columns := d.Get("columns").(*schema.Set)
privileges := d.Get("privileges").(*schema.Set)
objects := getter("objects").(*schema.Set)
columns := getter("columns").(*schema.Set)
privileges := getter("privileges").(*schema.Set)
if privileges.Len() == 0 || columns.Len() == 0 {
// No privileges to revoke, so don't revoke anything
query = "SELECT NULL"
Expand All @@ -631,38 +639,38 @@ func createRevokeQuery(d *schema.ResourceData) string {
"REVOKE %s (%s) ON TABLE %s FROM %s",
setToPgIdentSimpleList(privileges),
setToPgIdentListWithoutSchema(columns),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
setToPgIdentList(getter("schema").(string), objects),
pq.QuoteIdentifier(getter("role").(string)),
)
}
case "TABLE", "SEQUENCE", "FUNCTION", "PROCEDURE", "ROUTINE":
objects := d.Get("objects").(*schema.Set)
privileges := d.Get("privileges").(*schema.Set)
objects := getter("objects").(*schema.Set)
privileges := getter("privileges").(*schema.Set)
if objects.Len() > 0 {
if privileges.Len() > 0 {
// Revoking specific privileges instead of all privileges
// to avoid messing with column level grants
query = fmt.Sprintf(
"REVOKE %s ON %s %s FROM %s",
setToPgIdentSimpleList(privileges),
strings.ToUpper(d.Get("object_type").(string)),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
strings.ToUpper(getter("object_type").(string)),
setToPgIdentList(getter("schema").(string), objects),
pq.QuoteIdentifier(getter("role").(string)),
)
} else {
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON %s %s FROM %s",
strings.ToUpper(d.Get("object_type").(string)),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
strings.ToUpper(getter("object_type").(string)),
setToPgIdentList(getter("schema").(string), objects),
pq.QuoteIdentifier(getter("role").(string)),
)
}
} else {
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON ALL %sS IN SCHEMA %s FROM %s",
strings.ToUpper(d.Get("object_type").(string)),
pq.QuoteIdentifier(d.Get("schema").(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
strings.ToUpper(getter("object_type").(string)),
pq.QuoteIdentifier(getter("schema").(string)),
pq.QuoteIdentifier(getter("role").(string)),
)
}
}
Expand All @@ -675,24 +683,35 @@ func grantRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error {
for _, priv := range d.Get("privileges").(*schema.Set).List() {
privileges = append(privileges, priv.(string))
}

if len(privileges) == 0 {
log.Printf("[DEBUG] no privileges to grant for role %s in database: %s,", d.Get("role").(string), d.Get("database"))
return nil
}

query := createGrantQuery(d, privileges)
log.Printf("[INFO] executing %s", query)

_, err := txn.Exec(query)
return err
}

func revokeRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error {
query := createRevokeQuery(d)
func revokeRolePrivileges(txn *sql.Tx, d *schema.ResourceData, usePrevious bool) error {
var getter ResourceSchemGetter
if usePrevious {
getter = func(name string) interface{} {
old, _ := d.GetChange(name)
return old
}
} else {
getter = func(name string) interface{} {
return d.Get(name)
}
}
query := createRevokeQuery(getter)
if len(query) == 0 {
// Query is empty, don't run anything
return nil
}
log.Printf("[INFO] executing %s", query)
if _, err := txn.Exec(query); err != nil {
return fmt.Errorf("could not execute revoke query: %w", err)
}
Expand Down
5 changes: 4 additions & 1 deletion postgresql/resource_postgresql_grant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,10 @@ func TestCreateRevokeQuery(t *testing.T) {
}

for _, c := range cases {
out := createRevokeQuery(c.resource)
getter := func(name string) interface{} {
return c.resource.Get(name)
}
out := createRevokeQuery(getter)
if out != c.expected {
t.Fatalf("Error matching output and expected: %#v vs %#v", out, c.expected)
}
Expand Down
Loading