Skip to content

Commit

Permalink
Rename TranslationData to SourceMap (#8119)
Browse files: browse the repository at this point in the history
  • Loading branch information
msullivan authored Dec 13, 2024
1 parent e80225e commit 102e4de
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 54 deletions.
46 changes: 23 additions & 23 deletions edb/pgsql/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def generate(
add_line_information: bool = False,
pretty: bool = True,
reordered: bool = False,
with_translation_data: bool = False,
with_source_map: bool = False,
) -> SQLSource:
# Main entrypoint

Expand All @@ -52,7 +52,7 @@ def generate(
pretty=pretty,
),
reordered=reordered,
with_translation_data=with_translation_data,
with_source_map=with_source_map,
)

try:
Expand All @@ -72,12 +72,12 @@ def generate(
exceptions.add_context(err, ctx)
raise err from error

if with_translation_data:
assert generator.translation_data
if with_source_map:
assert generator.source_map

return SQLSource(
text=generator.finish(),
translation_data=generator.translation_data,
source_map=generator.source_map,
param_index=generator.param_index,
)

Expand Down Expand Up @@ -113,18 +113,18 @@ def generate_ctes_source(
return generator.finish()


class TranslationData:
class SourceMap:
@abc.abstractmethod
def translate(self, pos: int) -> int:
...


@dataclasses.dataclass(kw_only=True)
class BaseTranslationData(TranslationData):
class BaseSourceMap(SourceMap):
source_start: int
output_start: int
output_end: int | None = None
children: List[BaseTranslationData] = (
children: List[BaseSourceMap] = (
dataclasses.field(default_factory=list))

def translate(self, pos: int) -> int:
Expand All @@ -139,8 +139,8 @@ def translate(self, pos: int) -> int:


@dataclasses.dataclass
class ChainedTranslationData(TranslationData):
parts: List[TranslationData] = (
class ChainedSourceMap(SourceMap):
parts: List[SourceMap] = (
dataclasses.field(default_factory=list))

def translate(self, pos: int) -> int:
Expand All @@ -153,15 +153,15 @@ def translate(self, pos: int) -> int:
class SQLSource:
text: str
param_index: dict[int, list[int]]
translation_data: Optional[TranslationData] = None
source_map: Optional[SourceMap] = None


class SQLSourceGenerator(codegen.SourceGenerator):
def __init__(
self,
opts: codegen.Options,
*,
with_translation_data: bool = False,
with_source_map: bool = False,
reordered: bool = False,
):
super().__init__(
Expand All @@ -170,14 +170,14 @@ def __init__(
pretty=opts.pretty,
)
# params
self.with_translation_data: bool = with_translation_data
self.with_source_map: bool = with_source_map
self.reordered = reordered

# state
self.param_index: collections.defaultdict[int, list[int]] = (
collections.defaultdict(list))
self.write_index: int = 0
self.translation_data: Optional[BaseTranslationData] = None
self.source_map: Optional[BaseSourceMap] = None

def write(
self,
Expand All @@ -190,20 +190,20 @@ def write(
self.write_index += len(self.result[new])

def visit(self, node): # type: ignore
if self.with_translation_data:
translation_data = BaseTranslationData(
if self.with_source_map:
source_map = BaseSourceMap(
source_start=node.span.start if node.span else 0,
output_start=self.write_index,
)
old_top = self.translation_data
self.translation_data = translation_data
old_top = self.source_map
self.source_map = source_map
super().visit(node)
if self.with_translation_data:
assert self.translation_data == translation_data
self.translation_data.output_end = self.write_index
if self.with_source_map:
assert self.source_map == source_map
self.source_map.output_end = self.write_index
if old_top:
old_top.children.append(self.translation_data)
self.translation_data = old_top
old_top.children.append(self.source_map)
self.source_map = old_top

def generic_visit(self, node): # type: ignore
raise GeneratorError(
Expand Down
2 changes: 1 addition & 1 deletion edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,7 +2711,7 @@ def compile_sql_as_unit_group(
if sql_unit.cardinality is enums.Cardinality.NO_RESULT
else enums.OutputFormat.BINARY
),
translation_data=sql_unit.translation_data,
source_map=sql_unit.source_map,
sql_prefix_len=sql_unit.prefix_len,
)
match sql_unit.tx_action:
Expand Down
6 changes: 3 additions & 3 deletions edb/server/compiler/dbstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ class QueryUnit:
append_tx_op: bool = False

# Translation source map.
translation_data: Optional[pgcodegen.TranslationData] = None
source_map: Optional[pgcodegen.SourceMap] = None
# For SQL queries, the length of the query prefix applied
# after translation.
sql_prefix_len: int = 0
Expand Down Expand Up @@ -499,7 +499,7 @@ class PrepareData(PreparedStmtOpData):

query: str
"""Translated query string"""
translation_data: Optional[pgcodegen.TranslationData] = None
source_map: Optional[pgcodegen.SourceMap] = None
"""Translation source map"""


Expand All @@ -523,7 +523,7 @@ class SQLQueryUnit:
"""Translated query text."""

prefix_len: int = 0
translation_data: Optional[pgcodegen.TranslationData] = None
source_map: Optional[pgcodegen.SourceMap] = None
"""Translation source map."""

eql_format_query: Optional[str] = dataclasses.field(
Expand Down
20 changes: 10 additions & 10 deletions edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _try(q: str) -> List[dbstate.SQLQueryUnit]:
def _build_constant_extraction_map(
src: pgast.Base,
out: pgast.Base,
) -> pg_codegen.BaseTranslationData:
) -> pg_codegen.BaseSourceMap:
"""Traverse two ASTs in parallel and build a source map between them.
The ASTs should *mostly* line up. When they don't, that is
Expand All @@ -127,7 +127,7 @@ def _build_constant_extraction_map(
"parse" phase, so we don't need to worry about it being reused
with different constants.
"""
tdata = pg_codegen.BaseTranslationData(
tdata = pg_codegen.BaseSourceMap(
source_start=src.span.start if src.span else 0,
# HACK: I don't know why, but this - 1 helps a lot.
output_start=out.span.start - 1 if out.span else 0,
Expand Down Expand Up @@ -350,7 +350,7 @@ def _compile_sql(
stmt_name=stmt.name,
be_stmt_name=mangled_stmt_name.encode("utf-8"),
query=stmt_source.text,
translation_data=stmt_source.translation_data,
source_map=stmt_source.source_map,
)
unit.command_complete_tag = dbstate.TagPlain(tag=b"PREPARE")
track_stats = True
Expand Down Expand Up @@ -409,11 +409,11 @@ def _compile_sql(
stmt, schema, tx_state, opts
)
unit.query = stmt_source.text
unit.translation_data = stmt_source.translation_data
if stmt_source.translation_data:
unit.translation_data = (
pg_codegen.ChainedTranslationData([
stmt_source.translation_data,
unit.source_map = stmt_source.source_map
if stmt_source.source_map:
unit.source_map = (
pg_codegen.ChainedSourceMap([
stmt_source.source_map,
extract_data,
])
)
Expand Down Expand Up @@ -562,11 +562,11 @@ def resolve_query(
disambiguate_column_names=opts.disambiguate_column_names,
)
resolved = pg_resolver.resolve(stmt, schema, options)
source = pg_codegen.generate(resolved.ast, with_translation_data=True)
source = pg_codegen.generate(resolved.ast, with_source_map=True)
if resolved.edgeql_output_format_ast is not None:
edgeql_format_source = pg_codegen.generate(
resolved.edgeql_output_format_ast,
with_translation_data=True,
with_source_map=True,
)
else:
edgeql_format_source = None
Expand Down
4 changes: 2 additions & 2 deletions edb/server/dbview/dbview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1468,8 +1468,8 @@ cdef class DatabaseConnectionView:
ex.fields['P'] = str(
int(ex.fields['P']) - query_unit.sql_prefix_len
)
if query_unit.translation_data:
ex._translation_data = query_unit.translation_data
if query_unit.source_map:
ex._source_map = query_unit.source_map

raise

Expand Down
20 changes: 10 additions & 10 deletions edb/server/pgcon/pgcon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2154,14 +2154,14 @@ cdef class PGConnection:
msg_buf: WriteBuffer,
query: bytes,
pos_bytes: bytes,
translation_data: Optional[pg_codegen.TranslationData],
source_map: Optional[pg_codegen.SourceMap],
offset: int = 0,
):
if translation_data:
if source_map:
pos = int(pos_bytes.decode('utf8'))
if offset > 0 or pos + offset > 0:
pos += offset
pos = translation_data.translate(pos)
pos = source_map.translate(pos)
# pg uses 1-based indexes
pos += 1
pos_bytes = str(pos).encode('utf8')
Expand All @@ -2181,17 +2181,17 @@ cdef class PGConnection:
field_type = self.buffer.read_byte()
if field_type == b'P': # Position
if action.query_unit is None:
translation_data = None
source_map = None
offset = 0
else:
qu = action.query_unit
translation_data = qu.translation_data
source_map = qu.source_map
offset = -qu.prefix_len
self._write_error_position(
msg_buf,
action.args[0],
self.buffer.read_null_str(),
translation_data,
source_map,
offset,
)
continue
Expand Down Expand Up @@ -2234,21 +2234,21 @@ cdef class PGConnection:
query_text = qu.query.encode("utf-8")
if qu.prepare is not None:
offset = -55
translation_data = qu.prepare.translation_data
source_map = qu.prepare.source_map
else:
offset = 0
translation_data = qu.translation_data
source_map = qu.source_map
offset -= qu.prefix_len
else:
query_text = b""
translation_data = None
source_map = None
offset = 0

self._write_error_position(
msg_buf,
query_text,
self.buffer.read_null_str(),
translation_data,
source_map,
offset,
)
else:
Expand Down
10 changes: 5 additions & 5 deletions edb/server/protocol/execute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ async def execute(
if query_unit.user_schema:
if isinstance(ex, pgerror.BackendError):
ex._user_schema = query_unit.user_schema
if query_unit.translation_data:
if query_unit.source_map:
ex._from_sql = True

dbv.on_error()
Expand Down Expand Up @@ -574,7 +574,7 @@ async def execute_script(
if isinstance(e, pgerror.BackendError):
e._user_schema = dbv.get_user_schema_pickle()

if query_unit and query_unit.translation_data:
if query_unit and query_unit.source_map:
e._from_sql = True

if not in_tx and dbv.in_tx():
Expand Down Expand Up @@ -939,7 +939,7 @@ async def interpret_error(
elif isinstance(exc, pgerror.BackendError):
try:
from_sql = getattr(exc, '_from_sql', False)
translation_data = getattr(exc, '_translation_data', None)
source_map = getattr(exc, '_source_map', None)
fields = exc.fields

static_exc = errormech.static_interpret_backend_error(
Expand Down Expand Up @@ -978,12 +978,12 @@ async def interpret_error(
exc = errors.ExecutionError(*exc.args)

# Translate error position for SQL queries if we can
if translation_data and isinstance(exc, errors.EdgeDBError):
if source_map and isinstance(exc, errors.EdgeDBError):
if 'P' in fields:
exc.set_position(
0,
0,
translation_data.translate(int(fields['P'])),
source_map.translate(int(fields['P'])),
None,
)

Expand Down

0 comments on commit 102e4de

Please sign in to comment.