Skip to content

Commit

Permalink
Add new decorator based approach for duck queries
Browse files Browse the repository at this point in the history
  • Loading branch information
ajparsons committed Aug 9, 2023
1 parent edbef72 commit a026045
Showing 1 changed file with 109 additions and 9 deletions.
118 changes: 109 additions & 9 deletions src/data_common/db/duck.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
import inspect
from functools import lru_cache
from pathlib import Path
from typing import Any, Literal, Callable

from typing import Any, Literal, Callable, Protocol, runtime_checkable, Union
import duckdb
import jinja2
import pandas as pd
import toml


@runtime_checkable
class DuckView(Protocol):
    """
    Structural type for objects that describe a view via a ``query`` attribute.
    """

    # query text that DuckQuery.as_view passes on to DuckQuery.query
    query: str


@runtime_checkable
class DuckMacro(Protocol):
    """
    Structural type for objects that describe a macro.
    """

    # parameter names, joined with ", " into the macro signature by as_macro
    args: list[str]
    # macro body placed after AS in the CREATE OR REPLACE MACRO statement
    macro: str


@lru_cache
def get_settings(toml_file: str = "pyproject.toml") -> dict:
"""
Expand Down Expand Up @@ -122,17 +131,57 @@ def __truediv__(self, other: str) -> "DuckUrl":
return DuckUrl(f"{url}/{other}")


# Kinds of object that DuckQuery.register accepts as a data source
SourceType = Path | DuckUrl | pd.DataFrame


@runtime_checkable
class SourceView(Protocol):
    """
    Structural type for objects that expose a ``source`` to register.
    """

    @property
    def source(self) -> SourceType:
        """
        The data source that DuckQuery.as_source hands to DuckQuery.register.
        """
        ...


class DuckQuery:
def __init__(self):
    """
    Open an in-memory DuckDB connection and initialise query state.
    """
    self.ddb = duckdb.connect(":memory:")
    # nothing has been queried yet, so there is no last response
    self._last_query: DuckResponse | None = None
    # jinja variables merged into the keyword arguments of each query
    self.variables = {}
    # consulted by activate_https before installing/loading httpfs
    self.https: bool = False

def set_jinja_variable(self, name: str, value: Any) -> "DuckQuery":
"""
Set jinja variables that can then be used in queries
"""
self.variables[name] = value
return self

@property
def last_query(self):
"""
Get query for last view registered
"""
if not self._last_query:
raise ValueError("No previous query to execute")
return self._last_query

def activate_https(self) -> None:
if self.https is False:
self.ddb.execute("install httpfs; load httpfs")

def register(self, name: str, item: pd.DataFrame | DuckUrl | Path) -> None:
def as_source(self, item: SourceView) -> "DuckResponse":
"""
Decorator to convert something implementing SourceView to a DuckResponse
"""
name = item.__name__ # type: ignore
source = getattr(item, "source", None)

if source is None:
raise ValueError("Class must have a source attribute")

self.register(name, source)
return self.view(name)

def register(self, name: str, item: SourceType) -> None:
if isinstance(item, DuckUrl):
self.activate_https()
self.ddb.execute(
Expand All @@ -156,14 +205,38 @@ def add_view(self, name: str, query: str) -> "DuckQuery":
self.ddb.execute(f"CREATE OR REPLACE VIEW {name} AS {query}")
return self

def as_view(self, cls: DuckView) -> "DuckResponse":
"""
Decorator to convert something implementing DuckView to a DuckResponse
"""

query = getattr(cls, "query", None)

if query is None:
raise ValueError("Class must have a query method")

store_as_view = getattr(cls, "store_as_view", None) # type: ignore

if store_as_view is None:
store_as_view: str = cls.__name__ # type: ignore

return self.query(query, store_as=store_as_view)

def view(self, view_name: str):
""" """
return self.query(f"SELECT * FROM {view_name}")

def query(
self, query: str | Path, store_as: str | None = None, **kwargs: Any
self, query: str | Path | DuckView, store_as: str | None = None, **kwargs: Any
) -> DuckResponse:
"""
Execute a query
"""
if isinstance(query, DuckView):
return self.as_view(query)

query_vars = self.variables | kwargs

# if the query is a path, read it in
if isinstance(query, Path) or query.endswith(".sql"):
path = Path(query)
if not path.exists():
Expand All @@ -189,11 +262,11 @@ def process_kwarg(key: str, value: Any) -> Any:

return value

if kwargs:
if query_vars:
env = jinja2.Environment()
template = env.from_string(query)

args = {k: process_kwarg(k, v) for k, v in kwargs.items()}
args = {k: process_kwarg(k, v) for k, v in query_vars.items()}

rendered_query = template.render(**args)
else:
Expand All @@ -202,9 +275,36 @@ def process_kwarg(key: str, value: Any) -> Any:
if store_as:
self.ddb.execute(f"CREATE OR REPLACE VIEW {store_as} AS {rendered_query}")
rendered_query = f"SELECT * FROM {store_as}"
return DuckResponse(self, rendered_query)

response = DuckResponse(self, rendered_query)

self._last_query = response
return response

def as_macro(self, item: DuckMacro):
name = item.__name__ # type: ignore

args = getattr(item, "args", None)

if args is None:
raise ValueError("Macro must have an args attribute")

macro = getattr(item, "macro", None)

if macro is None:
raise ValueError("Macro must have a macro method")

macro_query = f"""
CREATE OR REPLACE MACRO {name}({", ".join(args)}) AS
{macro}
"""
self.query(macro_query).run()

return item

def macro(self, func: Callable[..., str]) -> None:
# deprecated: converts a function
# prefer 'as_macro' for clarity
# get function name
name = func.__name__
# get arguments
Expand Down

0 comments on commit a026045

Please sign in to comment.