From c70c24e08fb086ef2ae94a32395fe33af8fb01c3 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Thu, 31 Aug 2023 07:37:38 -0700 Subject: [PATCH] WIP for injected kwargs --- hamilton/function_modifiers/adapters.py | 34 +++++++++++++++++++++++++ hamilton/version.py | 2 +- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py index 333561516..d1a17e02a 100644 --- a/hamilton/function_modifiers/adapters.py +++ b/hamilton/function_modifiers/adapters.py @@ -1,3 +1,4 @@ +import enum import inspect import typing from typing import Any, Callable, Collection, Dict, List, Tuple, Type @@ -18,6 +19,14 @@ from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY +class InjectedKwargs(enum.Enum): + """These are special kwargs for data adapters that are declared and automatically injected by Hamilton. + These can""" + + node_name = "__node_name" + node_tags = "__node_tags" + + class AdapterFactory: """Factory for data loaders. This handles the fact that we pass in source(...) and value(...) parameters to the data loaders.""" @@ -60,6 +69,29 @@ def validate(self): f"Extra parameters for loader: {self.adapter_cls} {extra_params}" ) + def resolve_injected_kwargs(self, node_: node.Node) -> Dict[str, Any]: + """Resolves additional keyword arguments from the on which this operates + (passed into or extracted from). We may change how this works in the future, should + data loaders need metadata, but for now we'll be putting this in one function and + it'll be called by the data saver. + + :param node_: + :return: + """ + possible_args = { + InjectedKwargs.node_name.value: node_.name, + InjectedKwargs.node_tags.value: node_.tags, + } + out = {} + declared_arguments = { + **self.adapter_cls.get_optional_arguments(), + **self.adapter_cls.get_required_arguments(), + } + for item in declared_arguments: + if item in declared_arguments: + out[item] = possible_args[item] + return out + def create_loader(self, **resolved_kwargs: Any) -> DataLoader: if not self.adapter_cls.can_load(): raise InvalidDecoratorException(f"Adapter {self.adapter_cls} cannot load data.") @@ -473,6 +505,8 @@ def create_saver_node( adapter_factory = AdapterFactory(saver_cls, **self.kwargs) dependencies, resolved_kwargs = resolve_kwargs(self.kwargs) + injected_kwargs = adapter_factory.resolve_injected_kwargs(node_) + resolved_kwargs.update(injected_kwargs) dependencies_inverted = {v: k for k, v in dependencies.items()} def save_data( diff --git a/hamilton/version.py b/hamilton/version.py index 5fc8b1c7d..cfe2beea1 100644 --- a/hamilton/version.py +++ b/hamilton/version.py @@ -1 +1 @@ -VERSION = (1, 27, 2) +VERSION = (1, 27, 3)