Skip to content

Commit

Permalink
Rework record/replay to record at the database connection level.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Jun 21, 2024
1 parent 267cf5e commit 2d0d2fa
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 51 deletions.
239 changes: 191 additions & 48 deletions dbt/adapters/record.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,210 @@
import dataclasses
from io import StringIO
import json
import re
from typing import Any, Optional, Mapping
from typing import Any, Optional, Mapping, List, Union, Iterable

from agate import Table
from dbt.adapters.contracts.connection import Connection

from dbt_common.events.contextvars import get_node_info
from dbt_common.record import Record, Recorder
from dbt_common.record import Record, Recorder, record_function

from dbt.adapters.contracts.connection import AdapterResponse

class RecordReplayHandle:
    """Wraps a native database handle and hands out RecordReplayCursor objects."""

    def __init__(self, native_handle: Any, connection: Connection) -> None:
        self.connection = connection
        self.native_handle = native_handle

    def cursor(self):
        """Return a RecordReplayCursor around the native handle's cursor.

        The wrapped cursor is None when there is no native handle.
        """
        # In replay mode no real database access should be performed, so the
        # native handle (and therefore the native cursor) may be absent.
        if self.native_handle is None:
            native_cursor = None
        else:
            native_cursor = self.native_handle.cursor()
        return RecordReplayCursor(native_cursor, self.connection)


@dataclasses.dataclass
class QueryRecordParams:
    """Parameters identifying one recorded query.

    When ``node_unique_id`` is not supplied it is filled from the current
    node info's ``unique_id``, or set to ``""`` when there is no node info.
    """

    sql: str
    auto_begin: bool = False
    fetch: bool = False
    limit: Optional[int] = None
    node_unique_id: Optional[str] = None

    def __post_init__(self) -> None:
        if self.node_unique_id is None:
            node_info = get_node_info()
            self.node_unique_id = node_info["unique_id"] if node_info else ""

    @staticmethod
    def _clean_up_sql(sql: str) -> str:
        """Return ``sql`` with comments and all whitespace removed.

        The result is only meant for comparing two statements, not for
        execution. Comment markers inside string literals are not recognised.
        """
        # Strip both comment styles in one left-to-right pass so that whichever
        # comment starts first wins: a "--" inside /* ... */ no longer swallows
        # the closing "*/", and a trailing "--" comment without a final newline
        # is removed as well.
        sql = re.sub(r"--[^\n]*|/\*.*?\*/", "", sql, flags=re.DOTALL)
        # Drop every kind of whitespace (spaces, newlines, tabs, carriage
        # returns) so formatting differences never affect the comparison.
        return re.sub(r"\s+", "", sql)

    def _matches(self, other: "QueryRecordParams") -> bool:
        """Return True when both records name the same node and equivalent SQL."""
        if self.node_unique_id != other.node_unique_id:
            return False
        return self._clean_up_sql(self.sql) == self._clean_up_sql(other.sql)
class CursorExecuteParams:
    """Parameters recorded for a cursor ``execute`` call."""

    connection_name: str
    operation: str
    # ``RecordReplayCursor.execute`` declares ``parameters=None``, so the
    # recorded value must be allowed to be absent as well.
    parameters: Optional[Union[Iterable[Any], Mapping[str, Any]]] = None


class CursorExecuteRecord(Record):
    """Record type for cursor ``execute`` calls; it carries no result payload."""

    result_cls = None
    params_cls = CursorExecuteParams


# Register the record type with the Recorder.
Recorder.register_record_type(CursorExecuteRecord)


@dataclasses.dataclass
class QueryRecordResult:
    """Result of a recorded query: the adapter response and the result table."""

    # Both fields are optional.
    adapter_response: Optional["AdapterResponse"]
    table: Optional[Table]
class CursorFetchOneParams:
    """Parameters recorded for a cursor ``fetchone`` call."""

    # Name of the connection the cursor belongs to.
    connection_name: str

def _to_dict(self) -> Any:
    """Return a dict holding the adapter response dict and the table as JSON text."""
    json_buffer = StringIO()
    # The table writes its JSON form into the in-memory buffer.
    self.table.to_json(json_buffer)  # type: ignore
    response_dict = self.adapter_response.to_dict()  # type: ignore
    table_json = json_buffer.getvalue()
    return {"adapter_response": response_dict, "table": table_json}
@dataclasses.dataclass
class CursorFetchOneResult:
    """Result recorded for a cursor ``fetchone`` call."""

    # Whatever the native cursor's fetchone() returned.
    result: Any


class CursorFetchOneRecord(Record):
    """Record type pairing ``fetchone`` parameters with their result."""

    result_cls = CursorFetchOneResult
    params_cls = CursorFetchOneParams


# Register the record type with the Recorder.
Recorder.register_record_type(CursorFetchOneRecord)


@dataclasses.dataclass
class CursorFetchManyParams:
    """Parameters recorded for a cursor ``fetchmany`` call.

    ``RecordReplayCursor.fetchmany`` takes a ``size`` argument, so this class
    carries a matching field, the same way ``CursorExecuteParams`` mirrors the
    arguments of ``execute``.
    """

    connection_name: str
    # NOTE(review): presumably record_function forwards the call's arguments
    # to this constructor -- confirm against dbt_common.record. The default
    # keeps construction from the connection name alone working.
    size: Optional[int] = None


@dataclasses.dataclass
class CursorFetchManyResult:
    """Result recorded for a cursor ``fetchmany`` call."""

    # The rows returned by the native cursor's fetchmany().
    results: List[Any]


class CursorFetchManyRecord(Record):
    """Record type pairing ``fetchmany`` parameters with their result."""

    result_cls = CursorFetchManyResult
    params_cls = CursorFetchManyParams


# Register the record type with the Recorder.
Recorder.register_record_type(CursorFetchManyRecord)


@dataclasses.dataclass
class CursorFetchAllParams:
    """Parameters recorded for a cursor ``fetchall`` call."""

    # Name of the connection the cursor belongs to.
    connection_name: str


@dataclasses.dataclass
class CursorFetchAllResult:
    """Result recorded for a cursor ``fetchall`` call."""

    # The rows returned by the native cursor's fetchall().
    results: List[Any]


class CursorFetchAllRecord(Record):
    """Record type pairing ``fetchall`` parameters with their result."""

    result_cls = CursorFetchAllResult
    params_cls = CursorFetchAllParams


# Register the record type with the Recorder.
Recorder.register_record_type(CursorFetchAllRecord)


@dataclasses.dataclass
class CursorGetRowCountParams:
    """Parameters recorded when a cursor's ``rowcount`` is read."""

    # Name of the connection the cursor belongs to.
    connection_name: str


@dataclasses.dataclass
class CursorGetRowCountResult:
    """Result recorded when a cursor's ``rowcount`` is read."""

    # May be None as well as an integer.
    rowcount: Optional[int]


class CursorGetRowCountRecord(Record):
    """Record type pairing ``rowcount`` parameters with their result."""

    result_cls = CursorGetRowCountResult
    params_cls = CursorGetRowCountParams


# Register the record type with the Recorder.
Recorder.register_record_type(CursorGetRowCountRecord)


@dataclasses.dataclass
class CursorGetDescriptionParams:
    """Parameters recorded when a cursor's ``description`` is read."""

    # Name of the connection the cursor belongs to.
    connection_name: str


@dataclasses.dataclass
class RecordReplayColumn:
    """One column description, as rebuilt from a recorded cursor description.

    Field order matters: ``CursorGetDescriptionResult._from_dict`` constructs
    instances positionally.
    """

    name: str
    type_code: int
    display_size: Optional[int]
    internal_size: int
    null_ok: Optional[bool]
    precision: Optional[int]
    scale: Optional[int]
    table_column: Optional[int]
    table_oid: Optional[int]
table_oid: Optional[int]


@dataclasses.dataclass
class CursorGetDescriptionResult:
columns: Iterable[Any]

def _to_dict(self) -> Any:
    """Serialize the column descriptions into a plain dict.

    Returns:
        ``{"columns": [...]}`` with one dict per column. A ``None`` column
        collection is written as an empty list.
    """
    # DB-API cursors report ``description`` as None for statements that
    # return no rows; iterating None would raise TypeError, so treat it as
    # "no columns" instead.
    columns = self.columns if self.columns is not None else ()
    return {
        "columns": [
            {
                "name": c.name,
                "type_code": c.type_code,
                "display_size": c.display_size,
                "internal_size": c.internal_size,
                "null_ok": c.null_ok,
                "precision": c.precision,
                "scale": c.scale,
                "table_column": c.table_column,
                "table_oid": c.table_oid,
            }
            for c in columns
        ]
    }

@classmethod
def _from_dict(cls, dct: Mapping) -> "QueryRecordResult":
return QueryRecordResult(
adapter_response=AdapterResponse.from_dict(dct["adapter_response"]),
table=Table.from_object(json.loads(dct["table"])),
def _from_dict(cls, dct: Mapping) -> "CursorGetDescriptionResult":
    """Rebuild a result from its dict form, one RecordReplayColumn per entry."""
    # Keys are read in the positional order of RecordReplayColumn's fields.
    field_order = (
        "name",
        "type_code",
        "display_size",
        "internal_size",
        "null_ok",
        "precision",
        "scale",
        "table_column",
        "table_oid",
    )
    rebuilt = tuple(
        RecordReplayColumn(*[entry[key] for key in field_order])
        for entry in dct["columns"]
    )
    return CursorGetDescriptionResult(rebuilt)


class CursorGetDescriptionRecord(Record):
    """Record type pairing ``description`` parameters with their result."""

    result_cls = CursorGetDescriptionResult
    params_cls = CursorGetDescriptionParams


# Register the record type with the Recorder.
Recorder.register_record_type(CursorGetDescriptionRecord)


class RecordReplayCursor:
def __init__(self, native_cursor: Any, connection: Connection) -> None:
    """Keep the wrapped native cursor and the connection it belongs to."""
    self.connection = connection
    self.native_cursor = native_cursor

@record_function(CursorExecuteRecord, method=True, id_field_name="connection_name")
def execute(self, operation, parameters=None) -> None:
    """Run ``operation`` on the wrapped native cursor.

    Args:
        operation: The statement to execute.
        parameters: Optional bind parameters; omitted from the native call
            when None.
    """
    # Passing an explicit None is not equivalent to omitting the argument for
    # every DB-API driver (sqlite3, for one, rejects it), so only forward
    # parameters that were actually supplied.
    if parameters is None:
        self.native_cursor.execute(operation)
    else:
        self.native_cursor.execute(operation, parameters)

@record_function(CursorFetchOneRecord, method=True, id_field_name="connection_name")
def fetchone(self) -> Any:
    """Return the result of ``fetchone()`` on the wrapped native cursor."""
    row = self.native_cursor.fetchone()
    return row

@record_function(CursorFetchManyRecord, method=True, id_field_name="connection_name")
def fetchmany(self, size: int) -> Any:
    """Return the result of ``fetchmany(size)`` on the wrapped native cursor."""
    rows = self.native_cursor.fetchmany(size)
    return rows

@record_function(CursorFetchAllRecord, method=True, id_field_name="connection_name")
def fetchall(self) -> Any:
    """Return the result of ``fetchall()`` on the wrapped native cursor."""
    rows = self.native_cursor.fetchall()
    return rows

class QueryRecord(Record):
    """Record type pairing QueryRecordParams with QueryRecordResult."""

    result_cls = QueryRecordResult
    params_cls = QueryRecordParams
@property
def connection_name(self) -> Optional[str]:
    """Name of the connection this cursor belongs to.

    The record decorators on this class use it as their id field.
    """
    owning_connection = self.connection
    return owning_connection.name

@property
@record_function(CursorGetRowCountRecord, method=True, id_field_name="connection_name")
def rowcount(self) -> Optional[int]:
    """Return the ``rowcount`` attribute of the wrapped native cursor.

    Annotated as optional to agree with ``CursorGetRowCountResult.rowcount``,
    which is where the recorded value is stored.
    """
    return self.native_cursor.rowcount

Recorder.register_record_type(QueryRecord)
@property
@record_function(CursorGetDescriptionRecord, method=True, id_field_name="connection_name")
def description(self) -> Optional[Iterable[Any]]:
    """Return the ``description`` attribute of the wrapped native cursor.

    The value is a collection of column descriptions rather than a string (it
    is stored in ``CursorGetDescriptionResult.columns``). DB-API cursors may
    also report None here.
    """
    return self.native_cursor.description
3 changes: 0 additions & 3 deletions dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import DbtInternalError, NotImplementedError
from dbt_common.record import record_function
from dbt_common.utils import cast_to_str

from dbt.adapters.base import BaseConnectionManager
Expand All @@ -20,7 +19,6 @@
SQLQuery,
SQLQueryStatus,
)
from dbt.adapters.record import QueryRecord

if TYPE_CHECKING:
import agate
Expand Down Expand Up @@ -142,7 +140,6 @@ def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> "agate.Tab

return table_from_data_flat(data, column_names)

@record_function(QueryRecord, method=True, tuple_result=True)
def execute(
self,
sql: str,
Expand Down

0 comments on commit 2d0d2fa

Please sign in to comment.