Switch some files to use future annotations, follow ruff guidance to … #1108

Open · wants to merge 2 commits into main
Changes from 1 commit
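The change applied throughout this PR: adding "from __future__ import annotations" makes every annotation a string that is never evaluated at import time, so PEP 604 unions (str | None) and built-in generics (dict[str, int]) can be used on Python 3.8, provided nothing evaluates the hints at runtime. A minimal sketch of the rewrite, using an illustrative function rather than one from this repo:

from __future__ import annotations


# was: def lookup(values: Dict[str, int], key: Optional[str] = None) -> Optional[int]
def lookup(values: dict[str, int], key: str | None = None) -> int | None:
    # the annotations above are stored as strings, so this parses and runs on 3.8
    return values.get(key) if key is not None else None
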
33 changes: 19 additions & 14 deletions .pre-commit-config.yaml
@@ -5,6 +5,11 @@
# $ pre-commit install
exclude: '^docs/code-comparisons/' # skip the code comparisons directory
repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0
hooks:
- id: pyupgrade
args: [--py38-plus]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.7
@@ -15,17 +20,17 @@ repos:
# Run the formatter.
- id: ruff-format
# args: [ --diff ] # Use for previewing changes
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: trailing-whitespace
# ensures files are either empty or end with a blank line
- id: end-of-file-fixer
# sorts requirements
- id: requirements-txt-fixer
# valid python file
- id: check-ast
- repo: https://github.com/pycqa/flake8
rev: 7.1.1
hooks:
- id: flake8
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: trailing-whitespace
# ensures files are either empty or end with a blank line
- id: end-of-file-fixer
# sorts requirements
- id: requirements-txt-fixer
# valid python file
- id: check-ast
- repo: https://github.com/pycqa/flake8
rev: 7.1.1
hooks:
- id: flake8
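For context, the new pyupgrade hook automates modernizations like the ones that show up in the file diffs below. Its exact rewrites depend on the flag passed and on whether a file already imports future annotations (see the pyupgrade README), but two representative ones, sketched here with an illustrative class:

from __future__ import annotations


class Driver:
    def __init__(self) -> None:
        # was: super(Driver, self).__init__()
        super().__init__()  # zero-argument super(), valid on all of Python 3

    # was: def clone(self) -> "Driver":
    def clone(self) -> Driver:  # quotes dropped; the name is resolved lazily
        return Driver()
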
6 changes: 4 additions & 2 deletions hamilton/ad_hoc_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""A suite of tools for ad-hoc use"""

from __future__ import annotations

import atexit
import importlib.util
import linecache
@@ -9,7 +11,7 @@
import types
import uuid
from types import ModuleType
from typing import Callable, Optional
from typing import Callable


def _copy_func(f):
@@ -64,7 +66,7 @@ def create_temporary_module(*functions: Callable, module_name: str = None) -> Mo
return module


def module_from_source(source: str, module_name: Optional[str] = None) -> ModuleType:
def module_from_source(source: str, module_name: str | None = None) -> ModuleType:
"""Create a temporary module from source code."""
module_name = module_name if module_name else _generate_unique_temp_module_name()
module_object = ModuleType(module_name)
65 changes: 34 additions & 31 deletions hamilton/async_driver.py
@@ -1,22 +1,27 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import sys
import time
import typing
import uuid
from types import ModuleType
from typing import Any, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any

import hamilton.lifecycle.base as lifecycle_base
from hamilton import base, driver, graph, lifecycle, node, telemetry
from hamilton.execution.graph_functions import create_error_message
from hamilton.io.materialization import ExtractorFactory, MaterializerFactory

if TYPE_CHECKING:
from types import ModuleType

from hamilton.io.materialization import ExtractorFactory, MaterializerFactory

logger = logging.getLogger(__name__)


async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]:
async def await_dict_of_tasks(task_dict: dict[str, typing.Awaitable]) -> dict[str, Any]:
"""Util to await a dictionary of tasks as asyncio.gather is kind of garbage"""
keys = sorted(task_dict.keys())
coroutines = [task_dict[key] for key in keys]
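The if TYPE_CHECKING block above relies on the same mechanism: ModuleType, ExtractorFactory and MaterializerFactory are now referenced only inside annotations, which are never evaluated at runtime, so the imports can be limited to type checkers. A standalone sketch of the pattern (the describe function is illustrative):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only while type checking; skipped when the module actually runs
    from types import ModuleType


def describe(module: ModuleType) -> str:
    # ModuleType here is just the string "ModuleType" at runtime, so no NameError
    return module.__name__
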
@@ -42,7 +47,7 @@ class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilde
def __init__(
self,
result_builder: base.ResultMixin = None,
async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None,
async_lifecycle_adapters: lifecycle_base.LifecycleAdapterSet | None = None,
):
"""Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class.

@@ -52,7 +57,7 @@ def __init__(
2. This does *not* work with decorators when the async function is being decorated. That is\
because that function is called directly within the decorator, so we cannot await it.
"""
super(AsyncGraphAdapter, self).__init__()
super().__init__()
self.adapter = (
async_lifecycle_adapters
if async_lifecycle_adapters is not None
@@ -66,8 +71,8 @@ def do_node_execute(
*,
run_id: str,
node_: node.Node,
kwargs: typing.Dict[str, typing.Any],
task_id: Optional[str] = None,
kwargs: dict[str, typing.Any],
task_id: str | None = None,
) -> typing.Any:
"""Executes a node. Note this doesn't actually execute it -- rather, it returns a task.
This does *not* use async def, as we want it to be awaited on later -- this await is done
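The docstring above describes the trick behind do_node_execute: a plain (non-async) method builds an awaitable and hands it back, and the driver awaits it later. A minimal standalone sketch of that pattern, independent of Hamilton's classes:

import asyncio


def make_task(value: int):
    # plain def: nothing runs yet, we only construct and return a coroutine
    async def compute() -> int:
        await asyncio.sleep(0)
        return value * 2

    return compute()


async def main() -> None:
    pending = make_task(21)  # still not executed
    print(await pending)     # 42, awaited by the caller when it is ready


asyncio.run(main())
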
@@ -159,8 +164,8 @@ def build_result(self, **outputs: Any) -> Any:


def separate_sync_from_async(
adapters: typing.List[lifecycle.LifecycleAdapter],
) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]:
adapters: list[lifecycle.LifecycleAdapter],
) -> tuple[list[lifecycle.LifecycleAdapter], list[lifecycle.LifecycleAdapter]]:
"""Separates the sync and async adapters from a list of adapters.
Note this only works with hooks -- we'll be dealing with methods later.

@@ -196,8 +201,8 @@ def __init__(
self,
config,
*modules,
result_builder: Optional[base.ResultMixin] = None,
adapters: typing.List[lifecycle.LifecycleAdapter] = None,
result_builder: base.ResultMixin | None = None,
adapters: list[lifecycle.LifecycleAdapter] = None,
):
"""Instantiates an asynchronous driver.

@@ -229,7 +234,7 @@ def __init__(
)
# it will be defaulted by the graph adapter
result_builder = result_builders[0] if len(result_builders) == 1 else None
super(AsyncDriver, self).__init__(
super().__init__(
config,
*modules,
adapter=[
@@ -246,7 +251,7 @@ def __init__(
)
self.initialized = False

async def ainit(self) -> "AsyncDriver":
async def ainit(self) -> AsyncDriver:
"""Initializes the driver when using async. This only exists for backwards compatibility.
In Hamilton 2.0, we will be using an asynchronous constructor.
See https://dev.to/akarshan/asynchronous-python-magic-how-to-create-awaitable-constructors-with-asyncmixin-18j5.
@@ -267,12 +272,12 @@ async def ainit(self) -> "AsyncDriver":

async def raw_execute(
self,
final_vars: typing.List[str],
overrides: Dict[str, Any] = None,
final_vars: list[str],
overrides: dict[str, Any] = None,
display_graph: bool = False, # don't care
inputs: Dict[str, Any] = None,
inputs: dict[str, Any] = None,
_fn_graph: graph.FunctionGraph = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Executes the graph, returning a dictionary of strings (node keys) to final results.

:param final_vars: Variables to execute (+ upstream)
@@ -332,10 +337,10 @@ async def raw_execute(

async def execute(
self,
final_vars: typing.List[str],
overrides: Dict[str, Any] = None,
final_vars: list[str],
overrides: dict[str, Any] = None,
display_graph: bool = False,
inputs: Dict[str, Any] = None,
inputs: dict[str, Any] = None,
) -> Any:
"""Executes computation.

@@ -386,9 +391,9 @@ async def make_coroutine():

def capture_constructor_telemetry(
self,
error: Optional[str],
modules: Tuple[ModuleType],
config: Dict[str, Any],
error: str | None,
modules: tuple[ModuleType],
config: dict[str, Any],
adapter: base.HamiltonGraphAdapter,
):
"""Ensures we capture constructor telemetry the right way in an async context.
@@ -407,7 +412,7 @@ def capture_constructor_telemetry(
if loop.is_running():
loop.run_in_executor(
None,
super(AsyncDriver, self).capture_constructor_telemetry,
super().capture_constructor_telemetry,
error,
modules,
config,
@@ -450,22 +455,20 @@ class Builder(driver.Builder):
"""

def __init__(self):
super(Builder, self).__init__()
super().__init__()

def _not_supported(self, method_name: str, additional_message: str = ""):
raise ValueError(
f"Builder().{method_name}() is not supported for the async driver. {additional_message}"
)

def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder":
def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> Builder:
self._not_supported("enable_dynamic_execution")

def with_materializers(
self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory]
) -> "Builder":
def with_materializers(self, *materializers: ExtractorFactory | MaterializerFactory) -> Builder:
self._not_supported("with_materializers")

def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder":
def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> Builder:
self._not_supported(
"with_adapter",
"Use with_adapters instead to pass in the tracker (or other async hooks/methods)",
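One caveat to keep in mind when reviewing the new syntax: with future annotations the hints stay as unevaluated strings, which is exactly why str | None and list[str] are safe on Python 3.8; anything that forces evaluation (for example typing.get_type_hints) would still need an interpreter that understands the newer syntax. A quick demonstration:

from __future__ import annotations


def f(x: int | None) -> list[str]:
    return [] if x is None else [str(x)]


# prints {'x': 'int | None', 'return': 'list[str]'} -- plain strings, never evaluated
print(f.__annotations__)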