From 102e4de9ec9f23b2d345c8e2df4015097593149e Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 13 Dec 2024 12:52:28 -0800 Subject: [PATCH] Rename TranslationData to SourceMap (#8119) --- edb/pgsql/codegen.py | 46 ++++++++++++++++----------------- edb/server/compiler/compiler.py | 2 +- edb/server/compiler/dbstate.py | 6 ++--- edb/server/compiler/sql.py | 20 +++++++------- edb/server/dbview/dbview.pyx | 4 +-- edb/server/pgcon/pgcon.pyx | 20 +++++++------- edb/server/protocol/execute.pyx | 10 +++---- 7 files changed, 54 insertions(+), 54 deletions(-) diff --git a/edb/pgsql/codegen.py b/edb/pgsql/codegen.py index 9f9d0e7398f..a599f213b88 100644 --- a/edb/pgsql/codegen.py +++ b/edb/pgsql/codegen.py @@ -41,7 +41,7 @@ def generate( add_line_information: bool = False, pretty: bool = True, reordered: bool = False, - with_translation_data: bool = False, + with_source_map: bool = False, ) -> SQLSource: # Main entrypoint @@ -52,7 +52,7 @@ def generate( pretty=pretty, ), reordered=reordered, - with_translation_data=with_translation_data, + with_source_map=with_source_map, ) try: @@ -72,12 +72,12 @@ def generate( exceptions.add_context(err, ctx) raise err from error - if with_translation_data: - assert generator.translation_data + if with_source_map: + assert generator.source_map return SQLSource( text=generator.finish(), - translation_data=generator.translation_data, + source_map=generator.source_map, param_index=generator.param_index, ) @@ -113,18 +113,18 @@ def generate_ctes_source( return generator.finish() -class TranslationData: +class SourceMap: @abc.abstractmethod def translate(self, pos: int) -> int: ... @dataclasses.dataclass(kw_only=True) -class BaseTranslationData(TranslationData): +class BaseSourceMap(SourceMap): source_start: int output_start: int output_end: int | None = None - children: List[BaseTranslationData] = ( + children: List[BaseSourceMap] = ( dataclasses.field(default_factory=list)) def translate(self, pos: int) -> int: @@ -139,8 +139,8 @@ def translate(self, pos: int) -> int: @dataclasses.dataclass -class ChainedTranslationData(TranslationData): - parts: List[TranslationData] = ( +class ChainedSourceMap(SourceMap): + parts: List[SourceMap] = ( dataclasses.field(default_factory=list)) def translate(self, pos: int) -> int: @@ -153,7 +153,7 @@ def translate(self, pos: int) -> int: class SQLSource: text: str param_index: dict[int, list[int]] - translation_data: Optional[TranslationData] = None + source_map: Optional[SourceMap] = None class SQLSourceGenerator(codegen.SourceGenerator): @@ -161,7 +161,7 @@ def __init__( self, opts: codegen.Options, *, - with_translation_data: bool = False, + with_source_map: bool = False, reordered: bool = False, ): super().__init__( @@ -170,14 +170,14 @@ def __init__( pretty=opts.pretty, ) # params - self.with_translation_data: bool = with_translation_data + self.with_source_map: bool = with_source_map self.reordered = reordered # state self.param_index: collections.defaultdict[int, list[int]] = ( collections.defaultdict(list)) self.write_index: int = 0 - self.translation_data: Optional[BaseTranslationData] = None + self.source_map: Optional[BaseSourceMap] = None def write( self, @@ -190,20 +190,20 @@ def write( self.write_index += len(self.result[new]) def visit(self, node): # type: ignore - if self.with_translation_data: - translation_data = BaseTranslationData( + if self.with_source_map: + source_map = BaseSourceMap( source_start=node.span.start if node.span else 0, output_start=self.write_index, ) - old_top = self.translation_data - self.translation_data = translation_data + old_top = self.source_map + self.source_map = source_map super().visit(node) - if self.with_translation_data: - assert self.translation_data == translation_data - self.translation_data.output_end = self.write_index + if self.with_source_map: + assert self.source_map == source_map + self.source_map.output_end = self.write_index if old_top: - old_top.children.append(self.translation_data) - self.translation_data = old_top + old_top.children.append(self.source_map) + self.source_map = old_top def generic_visit(self, node): # type: ignore raise GeneratorError( diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 11d416603e8..20212325e4f 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -2711,7 +2711,7 @@ def compile_sql_as_unit_group( if sql_unit.cardinality is enums.Cardinality.NO_RESULT else enums.OutputFormat.BINARY ), - translation_data=sql_unit.translation_data, + source_map=sql_unit.source_map, sql_prefix_len=sql_unit.prefix_len, ) match sql_unit.tx_action: diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 2da50815625..89ce1b61948 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -346,7 +346,7 @@ class QueryUnit: append_tx_op: bool = False # Translation source map. - translation_data: Optional[pgcodegen.TranslationData] = None + source_map: Optional[pgcodegen.SourceMap] = None # For SQL queries, the length of the query prefix applied # after translation. sql_prefix_len: int = 0 @@ -499,7 +499,7 @@ class PrepareData(PreparedStmtOpData): query: str """Translated query string""" - translation_data: Optional[pgcodegen.TranslationData] = None + source_map: Optional[pgcodegen.SourceMap] = None """Translation source map""" @@ -523,7 +523,7 @@ class SQLQueryUnit: """Translated query text.""" prefix_len: int = 0 - translation_data: Optional[pgcodegen.TranslationData] = None + source_map: Optional[pgcodegen.SourceMap] = None """Translation source map.""" eql_format_query: Optional[str] = dataclasses.field( diff --git a/edb/server/compiler/sql.py b/edb/server/compiler/sql.py index b40ef91b9d6..0914cf62488 100644 --- a/edb/server/compiler/sql.py +++ b/edb/server/compiler/sql.py @@ -114,7 +114,7 @@ def _try(q: str) -> List[dbstate.SQLQueryUnit]: def _build_constant_extraction_map( src: pgast.Base, out: pgast.Base, -) -> pg_codegen.BaseTranslationData: +) -> pg_codegen.BaseSourceMap: """Traverse two ASTs in parallel and build a source map between them. The ASTs should *mostly* line up. When they don't, that is @@ -127,7 +127,7 @@ def _build_constant_extraction_map( "parse" phase, so we don't need to worry about it being reused with different constants. """ - tdata = pg_codegen.BaseTranslationData( + tdata = pg_codegen.BaseSourceMap( source_start=src.span.start if src.span else 0, # HACK: I don't know why, but this - 1 helps a lot. output_start=out.span.start - 1 if out.span else 0, @@ -350,7 +350,7 @@ def _compile_sql( stmt_name=stmt.name, be_stmt_name=mangled_stmt_name.encode("utf-8"), query=stmt_source.text, - translation_data=stmt_source.translation_data, + source_map=stmt_source.source_map, ) unit.command_complete_tag = dbstate.TagPlain(tag=b"PREPARE") track_stats = True @@ -409,11 +409,11 @@ def _compile_sql( stmt, schema, tx_state, opts ) unit.query = stmt_source.text - unit.translation_data = stmt_source.translation_data - if stmt_source.translation_data: - unit.translation_data = ( - pg_codegen.ChainedTranslationData([ - stmt_source.translation_data, + unit.source_map = stmt_source.source_map + if stmt_source.source_map: + unit.source_map = ( + pg_codegen.ChainedSourceMap([ + stmt_source.source_map, extract_data, ]) ) @@ -562,11 +562,11 @@ def resolve_query( disambiguate_column_names=opts.disambiguate_column_names, ) resolved = pg_resolver.resolve(stmt, schema, options) - source = pg_codegen.generate(resolved.ast, with_translation_data=True) + source = pg_codegen.generate(resolved.ast, with_source_map=True) if resolved.edgeql_output_format_ast is not None: edgeql_format_source = pg_codegen.generate( resolved.edgeql_output_format_ast, - with_translation_data=True, + with_source_map=True, ) else: edgeql_format_source = None diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 84c88b0e7a3..b979e78e756 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -1468,8 +1468,8 @@ cdef class DatabaseConnectionView: ex.fields['P'] = str( int(ex.fields['P']) - query_unit.sql_prefix_len ) - if query_unit.translation_data: - ex._translation_data = query_unit.translation_data + if query_unit.source_map: + ex._source_map = query_unit.source_map raise diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 203d3eae2aa..f93db056c74 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -2154,14 +2154,14 @@ cdef class PGConnection: msg_buf: WriteBuffer, query: bytes, pos_bytes: bytes, - translation_data: Optional[pg_codegen.TranslationData], + source_map: Optional[pg_codegen.SourceMap], offset: int = 0, ): - if translation_data: + if source_map: pos = int(pos_bytes.decode('utf8')) if offset > 0 or pos + offset > 0: pos += offset - pos = translation_data.translate(pos) + pos = source_map.translate(pos) # pg uses 1-based indexes pos += 1 pos_bytes = str(pos).encode('utf8') @@ -2181,17 +2181,17 @@ cdef class PGConnection: field_type = self.buffer.read_byte() if field_type == b'P': # Position if action.query_unit is None: - translation_data = None + source_map = None offset = 0 else: qu = action.query_unit - translation_data = qu.translation_data + source_map = qu.source_map offset = -qu.prefix_len self._write_error_position( msg_buf, action.args[0], self.buffer.read_null_str(), - translation_data, + source_map, offset, ) continue @@ -2234,21 +2234,21 @@ cdef class PGConnection: query_text = qu.query.encode("utf-8") if qu.prepare is not None: offset = -55 - translation_data = qu.prepare.translation_data + source_map = qu.prepare.source_map else: offset = 0 - translation_data = qu.translation_data + source_map = qu.source_map offset -= qu.prefix_len else: query_text = b"" - translation_data = None + source_map = None offset = 0 self._write_error_position( msg_buf, query_text, self.buffer.read_null_str(), - translation_data, + source_map, offset, ) else: diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 1547d14a966..6f38e239055 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -380,7 +380,7 @@ async def execute( if query_unit.user_schema: if isinstance(ex, pgerror.BackendError): ex._user_schema = query_unit.user_schema - if query_unit.translation_data: + if query_unit.source_map: ex._from_sql = True dbv.on_error() @@ -574,7 +574,7 @@ async def execute_script( if isinstance(e, pgerror.BackendError): e._user_schema = dbv.get_user_schema_pickle() - if query_unit and query_unit.translation_data: + if query_unit and query_unit.source_map: e._from_sql = True if not in_tx and dbv.in_tx(): @@ -939,7 +939,7 @@ async def interpret_error( elif isinstance(exc, pgerror.BackendError): try: from_sql = getattr(exc, '_from_sql', False) - translation_data = getattr(exc, '_translation_data', None) + source_map = getattr(exc, '_source_map', None) fields = exc.fields static_exc = errormech.static_interpret_backend_error( @@ -978,12 +978,12 @@ async def interpret_error( exc = errors.ExecutionError(*exc.args) # Translate error position for SQL queries if we can - if translation_data and isinstance(exc, errors.EdgeDBError): + if source_map and isinstance(exc, errors.EdgeDBError): if 'P' in fields: exc.set_position( 0, 0, - translation_data.translate(int(fields['P'])), + source_map.translate(int(fields['P'])), None, )