Skip to content

Commit

Permalink
Rework policies; Introduce separate sequence for destroy action; Remo…
Browse files Browse the repository at this point in the history
…ve references to NETWORK_RULE before dropping it
  • Loading branch information
littleK0i committed Oct 11, 2024
1 parent 3f1d601 commit 6660252
Show file tree
Hide file tree
Showing 58 changed files with 1,092 additions and 193 deletions.
24 changes: 16 additions & 8 deletions snowddl/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from snowddl.blueprint import Ident, ObjectType
from snowddl.config import SnowDDLConfig
from snowddl.engine import SnowDDLEngine
from snowddl.parser import default_parser_sequence, PermissionModelParser, PlaceholderParser
from snowddl.resolver import default_resolver_sequence
from snowddl.parser import default_parse_sequence, PermissionModelParser, PlaceholderParser
from snowddl.resolver import default_resolve_sequence, default_destroy_sequence
from snowddl.settings import SnowDDLSettings
from snowddl.version import __version__

Expand All @@ -24,8 +24,9 @@ class BaseApp:
application_name = "SnowDDL"
application_version = __version__

parser_sequence = default_parser_sequence
resolver_sequence = default_resolver_sequence
parse_sequence = default_parse_sequence
resolve_sequence = default_resolve_sequence
destroy_sequence = default_destroy_sequence

def __init__(self):
self.elapsed_timers = {}
Expand Down Expand Up @@ -167,7 +168,10 @@ def init_arguments_parser(self):
action="store_true",
)
parser.add_argument(
"--apply-all-policy", help="Additionally apply changes to all types of POLICIES", default=False, action="store_true"
"--apply-all-policy", help="Additionally apply changes for all types of policies", default=False, action="store_true"
)
parser.add_argument(
"--apply-account-level-policy", help="Additionally apply changes for ACCOUNT-level policies", default=False, action="store_true"
)
parser.add_argument(
"--apply-aggregation-policy",
Expand Down Expand Up @@ -322,7 +326,7 @@ def init_config(self):
exit(1)

# Blueprints
for parser_cls in self.parser_sequence:
for parser_cls in self.parse_sequence:
parser = parser_cls(config, self.config_path)
parser.load_blueprints()

Expand Down Expand Up @@ -361,12 +365,16 @@ def init_settings(self):
settings.execute_replace_table = True

if self.args.get("apply_all_policy"):
settings.execute_account_level_policy = True
settings.execute_aggregation_policy = True
settings.execute_masking_policy = True
settings.execute_projection_policy = True
settings.execute_row_access_policy = True
settings.execute_network_policy = True

if self.args.get("apply_account_level_policy"):
settings.execute_account_level_policy = True

if self.args.get("apply_aggregation_policy"):
settings.execute_aggregation_policy = True

Expand Down Expand Up @@ -494,7 +502,7 @@ def execute(self):
if not self.args.get("env_prefix") and not self.args.get("destroy_without_prefix"):
raise ValueError("Argument --env-prefix is required for [destroy] action")

for resolver_cls in self.resolver_sequence:
for resolver_cls in self.destroy_sequence:
with self.measure_elapsed_time(resolver_cls.__name__):
resolver = resolver_cls(self.engine)
resolver.destroy()
Expand All @@ -503,7 +511,7 @@ def execute(self):

self.engine.context.destroy_role_with_prefix()
else:
for resolver_cls in self.resolver_sequence:
for resolver_cls in self.resolve_sequence:
with self.measure_elapsed_time(resolver_cls.__name__):
resolver = resolver_cls(self.engine)
resolver.resolve()
Expand Down
13 changes: 7 additions & 6 deletions snowddl/app/singledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
SchemaObjectBlueprint,
)
from snowddl.config import SnowDDLConfig
from snowddl.parser import singledb_parser_sequence
from snowddl.resolver import singledb_resolver_sequence
from snowddl.parser import singledb_parse_sequence
from snowddl.resolver import singledb_resolve_sequence, singledb_destroy_sequence


class SingleDbApp(BaseApp):
parser_sequence = singledb_parser_sequence
resolver_sequence = singledb_resolver_sequence
parse_sequence = singledb_parse_sequence
resolve_sequence = singledb_resolve_sequence
destroy_sequence = singledb_destroy_sequence

def __init__(self):
self.config_db: Optional[DatabaseIdent] = None
Expand Down Expand Up @@ -286,15 +287,15 @@ def execute(self):
self.output_engine_context()

if self.args.get("action") == "destroy":
for resolver_cls in self.resolver_sequence:
for resolver_cls in self.destroy_sequence:
with self.measure_elapsed_time(resolver_cls.__name__):
resolver = resolver_cls(self.engine)
resolver.destroy()

error_count += len(resolver.errors)

else:
for resolver_cls in self.resolver_sequence:
for resolver_cls in self.resolve_sequence:
with self.measure_elapsed_time(resolver_cls.__name__):
resolver = resolver_cls(self.engine)
resolver.resolve()
Expand Down
2 changes: 2 additions & 0 deletions snowddl/blueprint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@
from .object_type import ObjectType
from .permission_model import PermissionModel, PermissionModelCreateGrant, PermissionModelFutureGrant, PermissionModelRuleset
from .reference import (
AbstractPolicyReference,
AggregationPolicyReference,
ForeignKeyReference,
IndexReference,
MaskingPolicyReference,
NetworkPolicyReference,
ProjectionPolicyReference,
RowAccessPolicyReference,
TagReference,
Expand Down
13 changes: 9 additions & 4 deletions snowddl/blueprint/blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
)
from .object_type import ObjectType
from .reference import (
AbstractPolicyReference,
AggregationPolicyReference,
ForeignKeyReference,
IndexReference,
MaskingPolicyReference,
NetworkPolicyReference,
ProjectionPolicyReference,
RowAccessPolicyReference,
TagReference,
Expand Down Expand Up @@ -66,7 +68,7 @@ class AccountParameterBlueprint(AbstractBlueprint):

class AggregationPolicyBlueprint(SchemaObjectBlueprint):
body: str
references: List[AggregationPolicyReference]
references: List[AggregationPolicyReference] = []


class AlertBlueprint(SchemaObjectBlueprint):
Expand Down Expand Up @@ -202,13 +204,16 @@ class MaskingPolicyBlueprint(SchemaObjectBlueprint):
returns: DataType
body: str
exempt_other_policies: bool = False
references: List[MaskingPolicyReference]
references: List[MaskingPolicyReference] = []


class NetworkPolicyBlueprint(AbstractBlueprint):
full_name: AccountObjectIdent
allowed_network_rule_list: List[SchemaObjectIdent] = []
blocked_network_rule_list: List[SchemaObjectIdent] = []
allowed_ip_list: List[str] = []
blocked_ip_list: List[str] = []
references: List[NetworkPolicyReference] = []


class NetworkRuleBlueprint(SchemaObjectBlueprint):
Expand Down Expand Up @@ -264,7 +269,7 @@ class ProcedureBlueprint(SchemaObjectBlueprint):

class ProjectionPolicyBlueprint(SchemaObjectBlueprint):
body: str
references: List[ProjectionPolicyReference]
references: List[ProjectionPolicyReference] = []


class ResourceMonitorBlueprint(AbstractBlueprint):
Expand All @@ -277,7 +282,7 @@ class ResourceMonitorBlueprint(AbstractBlueprint):
class RowAccessPolicyBlueprint(SchemaObjectBlueprint):
arguments: List[NameWithType]
body: str
references: List[RowAccessPolicyReference]
references: List[RowAccessPolicyReference] = []


class SchemaBlueprint(AbstractBlueprint):
Expand Down
20 changes: 15 additions & 5 deletions snowddl/blueprint/reference.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from abc import ABC
from typing import List, Optional

from .ident import Ident, SchemaObjectIdent
from .ident import AbstractIdent, Ident, SchemaObjectIdent
from .object_type import ObjectType
from ..model import BaseModelWithConfig


class AggregationPolicyReference(BaseModelWithConfig):
class AbstractPolicyReference(BaseModelWithConfig, ABC):
pass


class AggregationPolicyReference(AbstractPolicyReference):
object_type: ObjectType
object_name: SchemaObjectIdent
columns: List[Ident]
Expand All @@ -22,19 +27,24 @@ class IndexReference(BaseModelWithConfig):
include: Optional[List[Ident]]


class MaskingPolicyReference(BaseModelWithConfig):
class MaskingPolicyReference(AbstractPolicyReference):
object_type: ObjectType
object_name: SchemaObjectIdent
columns: List[Ident]


class ProjectionPolicyReference(BaseModelWithConfig):
class NetworkPolicyReference(AbstractPolicyReference):
object_type: ObjectType
object_name: Optional[AbstractIdent]


class ProjectionPolicyReference(AbstractPolicyReference):
object_type: ObjectType
object_name: SchemaObjectIdent
column: Ident


class RowAccessPolicyReference(BaseModelWithConfig):
class RowAccessPolicyReference(AbstractPolicyReference):
object_type: ObjectType
object_name: SchemaObjectIdent
columns: List[Ident]
Expand Down
12 changes: 12 additions & 0 deletions snowddl/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from snowddl.blueprint import (
AbstractBlueprint,
AbstractIdentWithPrefix,
AbstractPolicyReference,
ObjectType,
PermissionModel,
PermissionModelRuleset,
Expand Down Expand Up @@ -108,6 +109,17 @@ def remove_blueprint(self, bp: AbstractBlueprint):

del self.blueprints[bp.__class__][str(bp.full_name)]

def add_policy_reference_to_blueprint(self, cls: Type[T_Blueprint], policy_name: AbstractIdentWithPrefix, ref: AbstractPolicyReference):
all_blueprints = self.blueprints.get(cls, {})

if "references" not in cls.model_fields:
raise ValueError(f"{cls.__name__} does not have field [references], probably not a policy")

if str(policy_name) not in all_blueprints:
raise ValueError(f"{cls.__name__} with name [{policy_name}] does not exist or was not defined yet")

all_blueprints[str(policy_name)].references.append(ref)

def add_error(self, path: Path, e: Exception):
self.errors.append(
{
Expand Down
28 changes: 17 additions & 11 deletions snowddl/parser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._parsed_file import ParsedFile
from .account_params import AccountParameterParser
from .account_policy import AccountPolicyParser
from .aggregation_policy import AggregationPolicyParser
from .alert import AlertParser
from .business_role import BusinessRoleParser
Expand Down Expand Up @@ -37,10 +38,17 @@
from .warehouse import WarehouseParser


default_parser_sequence = [
default_parse_sequence = [
AccountParameterParser,
# --
AggregationPolicyParser,
MaskingPolicyParser,
NetworkPolicyParser,
ProjectionPolicyParser,
RowAccessPolicyParser,
ResourceMonitorParser,
AccountPolicyParser,
# --
WarehouseParser,
DatabaseParser,
SchemaParser,
Expand All @@ -63,19 +71,21 @@
ViewParser,
PipeParser,
TaskParser,
AggregationPolicyParser,
MaskingPolicyParser,
ProjectionPolicyParser,
RowAccessPolicyParser,
AlertParser,
# --
OutboundShareParser,
TechnicalRoleParser,
BusinessRoleParser,
UserParser,
AlertParser,
]


singledb_parser_sequence = [
singledb_parse_sequence = [
AggregationPolicyParser,
MaskingPolicyParser,
ProjectionPolicyParser,
RowAccessPolicyParser,
# --
DatabaseParser,
SchemaParser,
SecretParser,
Expand All @@ -96,9 +106,5 @@
ViewParser,
PipeParser,
TaskParser,
AggregationPolicyParser,
MaskingPolicyParser,
ProjectionPolicyParser,
RowAccessPolicyParser,
AlertParser,
]
3 changes: 3 additions & 0 deletions snowddl/parser/account_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ def process_account_params(self, f: ParsedFile):
comment=None,
)

if str(bp.full_name) == "NETWORK_POLICY":
raise ValueError("NETWORK_POLICY in account_params.yaml is no longer supported. Please use account_policy.yaml instead. Read more: https://docs.snowddl.com/breaking-changes-log/0.33.0-october-2024")

self.config.add_blueprint(bp)
37 changes: 37 additions & 0 deletions snowddl/parser/account_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from snowddl.blueprint import AccountObjectIdent, NetworkPolicyBlueprint, NetworkPolicyReference, ObjectType
from snowddl.parser.abc_parser import AbstractParser, ParsedFile

from typing import Type


# fmt: off
account_policy_json_schema = {
"type": "object",
"properties": {
"network_policy": {
"type": "string"
},
},
"additionalProperties": False
}
# fmt: on


class AccountPolicyParser(AbstractParser):
def load_blueprints(self):
self.parse_single_file(self.base_path / "account_policy.yaml", account_policy_json_schema, self.process_account_policy)

def process_account_policy(self, f: ParsedFile):
if self.env_prefix:
# Account-level policies are ignored with env_prefix is present
# Can only assign one account-level policy per account, no way around it
return

if f.params.get("network_policy"):
policy_name = AccountObjectIdent(self.env_prefix, f.params.get("network_policy"))

ref = NetworkPolicyReference(
object_type=ObjectType.ACCOUNT,
)

self.config.add_policy_reference_to_blueprint(NetworkPolicyBlueprint, policy_name, ref)
Loading

0 comments on commit 6660252

Please sign in to comment.