Skip to content

Commit

Permalink
refactor, add doc strings
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Jul 15, 2024
1 parent 4c81550 commit a5f305f
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions edb/pgsql/resolver/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,18 @@ def _pull_columns_from_table(

def compile_dml(stmt: pgast.Base, *, ctx: Context) -> None:
# extract all dml stmts
dml_stmts_sql = collect_dml_stmts(stmt)
dml_stmts_sql = _collect_dml_stmts(stmt)
if len(dml_stmts_sql) == 0:
return

# preprocess each SQL dml stmt into EdgeQL
stmts = [preprocess_insert_stmt(s, ctx=ctx) for s in dml_stmts_sql]
stmts = [_preprocess_insert_stmt(s, ctx=ctx) for s in dml_stmts_sql]

# merge EdgeQL stmts & compile to SQL
ctx.compiled_dml = compile_preprocessed_dml(stmts, ctx=ctx)
ctx.compiled_dml = _compile_preprocessed_dml(stmts, ctx=ctx)


def collect_dml_stmts(stmt: pgast.Base) -> List[pgast.InsertStmt]:
def _collect_dml_stmts(stmt: pgast.Base) -> List[pgast.InsertStmt]:
if not isinstance(stmt, pgast.Query):
return []

Expand Down Expand Up @@ -182,9 +182,15 @@ class PreprocessedDML:
early_result: context.CompiledDML


def preprocess_insert_stmt(
def _preprocess_insert_stmt(
stmt: pgast.InsertStmt, *, ctx: Context
) -> PreprocessedDML:
"""
Takes SQL INSERT query and produces an equivalent EdgeQL insert query
and a bunch of metadata needed to extract associated CTEs from result of the
EdgeQL compiler.
"""

# determine the subject object we are inserting into
assert isinstance(stmt.relation, pgast.RelRangeVar)
assert isinstance(stmt.relation.relation, pgast.Relation)
Expand All @@ -209,13 +215,13 @@ def preprocess_insert_stmt(
)

# handle DEFAULT and prepare the value relation
value_relation, expected_columns = preprocess_insert_value(
value_relation, expected_columns = _preprocess_insert_value(
stmt.select_stmt, stmt.ctes, expected_columns
)

# if we are sure that we are inserting a single row
# we can skip for-loops and iterators, which produces better SQL
is_value_single = has_at_most_one_row(stmt.select_stmt)
is_value_single = _has_at_most_one_row(stmt.select_stmt)

# prepare anchors for inserted value columns
value_name = ctx.alias_generator.get('ins_val')
Expand All @@ -240,11 +246,11 @@ def preprocess_insert_stmt(
value_columns = []
insert_shape = []
for expected_col in expected_columns:
ptr, ptr_name, is_link = get_pointer_for_column(expected_col, sub, ctx)
ptr, ptr_name, is_link = _get_pointer_for_column(expected_col, sub, ctx)
value_columns.append((expected_col, ptr))

# prepare the outputs of the source CTE
ptr_id = get_ptr_id(value_id, ptr, ctx)
ptr_id = _get_ptr_id(value_id, ptr, ctx)
output_var = pgast.ColumnRef(name=(ptr_name,), nullable=True)
if is_link:
value_rel.path_outputs[(ptr_id, 'identity')] = output_var
Expand All @@ -254,7 +260,7 @@ def preprocess_insert_stmt(

# prepare insert shape that will use the paths from source_outputs
insert_shape.append(
construct_insert_element_for_ptr(
_construct_insert_element_for_ptr(
value_ql,
ptr_name,
ptr,
Expand Down Expand Up @@ -295,7 +301,7 @@ def preprocess_insert_stmt(
if column.hidden:
continue

ptr, ptr_name, is_link = get_pointer_for_column(column, sub, ctx)
ptr, ptr_name, is_link = _get_pointer_for_column(column, sub, ctx)
select_shape.append(
qlast.ShapeElement(
expr=qlast.Path(steps=[qlast.Ptr(name=ptr_name)]),
Expand All @@ -319,7 +325,6 @@ def preprocess_insert_stmt(
value_relation_input=value_relation,
value_columns=value_columns,
value_iterator_name=value_iterator,

# these will be populated after compilation
output_ctes=[],
output_relation_name='',
Expand All @@ -328,7 +333,7 @@ def preprocess_insert_stmt(
)


def has_at_most_one_row(query: pgast.Query | None) -> bool:
def _has_at_most_one_row(query: pgast.Query | None) -> bool:
return isinstance(query, pgast.SelectStmt) and (
(query.values and len(query.values) == 1)
or (
Expand All @@ -338,7 +343,7 @@ def has_at_most_one_row(query: pgast.Query | None) -> bool:
)


def preprocess_insert_value(
def _preprocess_insert_value(
value_query: Optional[pgast.Query],
value_ctes: Optional[List[pgast.CommonTableExpr]],
expected_columns: List[context.Column],
Expand Down Expand Up @@ -393,9 +398,19 @@ def is_default(e: pgast.BaseExpr) -> bool:
return value_query, expected_columns


def compile_preprocessed_dml(
def _compile_preprocessed_dml(
stmts: List[PreprocessedDML], ctx: context.ResolverContextLevel
) -> Mapping[pgast.Query, context.CompiledDML]:
"""
Compiles *all* DML statements in the query.
Statements must already be preprocessed into equivalent EdgeQL statements.
Will merge the statements into one large shape of all DML queries and
compile that with a single invocation of EdgeQL compiler.
Returns mapping from the original SQL statement into CompiledDML.
"""

# merge params
singletons = set()
anchors: Dict[str, irast.PathId] = {}
Expand Down Expand Up @@ -508,9 +523,9 @@ def _merge_and_prepare_external_rels(
stmt_names: List[str],
) -> Tuple[
Dict[irast.PathId, Tuple[pgast.BaseRelation, Tuple[str, ...]]],
List[irast.MutatingStmt]
List[irast.MutatingStmt],
]:
# construct external rels
"""Construct external rels used for compiling all DML statements at once."""

# this should be straight-forward, but because we put DML into with
# bindings, ql compiler will put each binding into a separate namespace.
Expand Down Expand Up @@ -640,7 +655,7 @@ def resolve_InsertStmt(
ctx.ctes_buffer.extend(compiled_dml.output_ctes)

if stmt.returning_list:
res_query, res_table = returning_rows(
res_query, res_table = _resolve_returning_rows(
stmt.returning_list,
compiled_dml.output_relation_name,
compiled_dml.output_namespace,
Expand All @@ -657,15 +672,13 @@ def resolve_InsertStmt(
return res_query, res_table


def returning_rows(
def _resolve_returning_rows(
returning_list: List[pgast.ResTarget],
output_relation_name: str,
output_namespace: Mapping[str, pgast.BaseExpr],
subject_alias: Optional[str],
ctx: context.ResolverContextLevel,
) -> Tuple[pgast.Query, context.Table]:
# extract pointers to be used in returning columns

# relation that provides the values of inserted pointers
inserted_rvar_name = ctx.alias_generator.get('ins')
inserted_query = pgast.SelectStmt(
Expand Down Expand Up @@ -712,7 +725,7 @@ def returning_rows(
return returning_query, returning_table


def construct_insert_element_for_ptr(
def _construct_insert_element_for_ptr(
source_ql: qlast.PathElement,
ptr_name: str,
ptr: s_pointers.Pointer,
Expand Down Expand Up @@ -751,7 +764,7 @@ def construct_insert_element_for_ptr(
)


def get_pointer_for_column(
def _get_pointer_for_column(
col: context.Column,
subject_stype: s_objtypes.ObjectType,
ctx: context.ResolverContextLevel,
Expand All @@ -770,7 +783,7 @@ def get_pointer_for_column(
return ptr, ptr_name, is_link


def get_ptr_id(
def _get_ptr_id(
source_id: irast.PathId,
ptr: s_pointers.Pointer,
ctx: context.ResolverContextLevel,
Expand Down

0 comments on commit a5f305f

Please sign in to comment.