diff --git a/src/data_common/db/duck.py b/src/data_common/db/duck.py index d4dfbfb..6121705 100644 --- a/src/data_common/db/duck.py +++ b/src/data_common/db/duck.py @@ -1,14 +1,23 @@ -import inspect from functools import lru_cache from pathlib import Path -from typing import Any, Literal, Callable - +from typing import Any, Literal, Callable, Protocol, runtime_checkable, Union import duckdb import jinja2 import pandas as pd import toml +@runtime_checkable +class DuckView(Protocol): + query: str + + +@runtime_checkable +class DuckMacro(Protocol): + args: list[str] + macro: str + + @lru_cache def get_settings(toml_file: str = "pyproject.toml") -> dict: """ @@ -122,17 +131,57 @@ def __truediv__(self, other: str) -> "DuckUrl": return DuckUrl(f"{url}/{other}") +SourceType = Path | DuckUrl | pd.DataFrame + + +@runtime_checkable +class SourceView(Protocol): + @property + def source(self) -> SourceType: + ... + + class DuckQuery: def __init__(self): self.ddb = duckdb.connect(":memory:") self.https: bool = False + self.variables = {} + self._last_query: DuckResponse | None = None + + def set_jinja_variable(self, name: str, value: Any) -> "DuckQuery": + """ + Set jinja variables that can then be used in queries + """ + self.variables[name] = value + return self + + @property + def last_query(self): + """ + Get query for last view registered + """ + if not self._last_query: + raise ValueError("No previous query to execute") + return self._last_query def activate_https(self) -> None: if self.https is False: self.ddb.execute("install httpfs; load httpfs") - def register(self, name: str, item: pd.DataFrame | DuckUrl | Path) -> None: + def as_source(self, item: SourceView) -> "DuckResponse": + """ + Decorator to convert something implementing SourceView to a DuckResponse + """ + name = item.__name__ # type: ignore + source = getattr(item, "source", None) + + if source is None: + raise ValueError("Class must have a source attribute") + self.register(name, source) + return 
self.view(name) + + def register(self, name: str, item: SourceType) -> None: if isinstance(item, DuckUrl): self.activate_https() self.ddb.execute( @@ -156,14 +205,38 @@ def add_view(self, name: str, query: str) -> "DuckQuery": self.ddb.execute(f"CREATE OR REPLACE VIEW {name} AS {query}") return self + def as_view(self, cls: DuckView) -> "DuckResponse": + """ + Decorator to convert something implementing DuckView to a DuckResponse + """ + + query = getattr(cls, "query", None) + + if query is None: + raise ValueError("Class must have a query method") + + store_as_view = getattr(cls, "store_as_view", None) # type: ignore + + if store_as_view is None: + store_as_view: str = cls.__name__ # type: ignore + + return self.query(query, store_as=store_as_view) + + def view(self, view_name: str): + """ """ + return self.query(f"SELECT * FROM {view_name}") + def query( - self, query: str | Path, store_as: str | None = None, **kwargs: Any + self, query: str | Path | DuckView, store_as: str | None = None, **kwargs: Any ) -> DuckResponse: """ Execute a query """ + if isinstance(query, DuckView): + return self.as_view(query) + + query_vars = self.variables | kwargs - # if the query is a path, read it in if isinstance(query, Path) or query.endswith(".sql"): path = Path(query) if not path.exists(): @@ -189,11 +262,11 @@ def process_kwarg(key: str, value: Any) -> Any: return value - if kwargs: + if query_vars: env = jinja2.Environment() template = env.from_string(query) - args = {k: process_kwarg(k, v) for k, v in kwargs.items()} + args = {k: process_kwarg(k, v) for k, v in query_vars.items()} rendered_query = template.render(**args) else: @@ -202,9 +275,36 @@ def process_kwarg(key: str, value: Any) -> Any: if store_as: self.ddb.execute(f"CREATE OR REPLACE VIEW {store_as} AS {rendered_query}") rendered_query = f"SELECT * FROM {store_as}" - return DuckResponse(self, rendered_query) + + response = DuckResponse(self, rendered_query) + + self._last_query = response + return response + + 
def as_macro(self, item: DuckMacro): + name = item.__name__ # type: ignore + + args = getattr(item, "args", None) + + if args is None: + raise ValueError("Macro must have an args attribute") + + macro = getattr(item, "macro", None) + + if macro is None: + raise ValueError("Macro must have a macro method") + + macro_query = f""" + CREATE OR REPLACE MACRO {name}({", ".join(args)}) AS + {macro} + """ + self.query(macro_query).run() + + return item def macro(self, func: Callable[..., str]) -> None: + # deprecated: converts a function + # prefer 'as_macro' for clarity # get function name name = func.__name__ # get arguments