Skip to content

Commit

Permalink
Include the line number of the class declaration in binding keys
Browse files Browse the repository at this point in the history
Previously, we were using only the class name and filename as the
binding key, which caused collisions if the same name was used to
declare more than one (local) class in the same file (as in
test_registry_attrs.py).

In the process, this simplifies the process of getting the stack frames:
we had been retrieving the same frames (redundantly) in two different
ways in the process of defining a field, but the potential sharing was
somewhat obscured because the code was split into very short functions.

Fixes #43.
  • Loading branch information
bcmills committed Oct 15, 2024
1 parent cb530aa commit 0a36a46
Showing 1 changed file with 56 additions and 61 deletions.
117 changes: 56 additions & 61 deletions minject/inject_attrs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from collections import defaultdict
from dataclasses import dataclass
from platform import python_version
from typing import Any, DefaultDict, Dict, List, Optional, Type, TypeVar

Expand All @@ -11,11 +12,6 @@
_T = TypeVar("_T")
_P = TypeVar("_P")

_DEPTH_OF_INJECT_FIELD_CALLER = 3
_DEPTH_OF_INJECT_DEFINE_CALLER = 3
_DEPTH_OF_INJECT_DEFINE_CALLER_IF_NO_ARGS = 2
_DEPTH_OF_VAR_TO_WHICH_BINDING_IS_ASSIGNED = 2

_INJECT_DEFINE_DEFINE_KWARGS_DEFAULT_VAL: Dict[str, Any] = {}


Expand Down Expand Up @@ -89,60 +85,52 @@ def _get_compatible_attrs_define_kwargs() -> Dict[str, bool]:
return attrs_define_kwargs


def _get_calling_function_name(depth: int) -> str:
return inspect.stack()[depth].function


def _get_calling_function_file(depth: int) -> str:
return inspect.stack()[depth].filename


def _build_key(func_name: str, func_file: str) -> str:
return f"__{func_name}__{func_file}__"


def _get_calling_function_key_from_depth(depth: int) -> str:
func_name = _get_calling_function_name(depth=depth)
func_file = _get_calling_function_file(depth=depth)
return _build_key(func_name=func_name, func_file=func_file)
@dataclass(frozen=True)
class _BindingKey:
__slots__ = ("filename", "class_lineno")
filename: str
class_lineno: int # The line containing the "class" keyword.


def _get_calling_function_key_from_filename_and_key(func_name: str, func_file: str) -> str:
return _build_key(func_name=func_name, func_file=func_file)
_key_binding_mapping: DefaultDict[_BindingKey, dict] = defaultdict(lambda: {})


_key_binding_mapping: DefaultDict[str, dict] = defaultdict(lambda: {})


def _get_init_kwarg_assignment() -> str:
def inject_field(binding=_T, **attr_field_kwargs) -> Any:
"""
get the name of the variable that will be assigned the return
value of the function that calls this function.
Wrapper around attr.field which takes an argument to specify registry
bindings
"""
frame = inspect.currentframe()
outer_frame = inspect.getouterframes(frame)[_DEPTH_OF_VAR_TO_WHICH_BINDING_IS_ASSIGNED]
optional_code_context = inspect.getframeinfo(outer_frame[0]).code_context
if not optional_code_context:
stack = inspect.stack()
# The first frame of the stack is the call to inject_field itself.

# We assume that inject_field is called directly (not via some kind of
# wrapper), so the second frame of the stack should be the field
# declaration. Extract the name of the field.
field_frame = stack[1]
name = ""
if len(field_frame.code_context) > 0:
code = field_frame.code_context[0].strip()
name_and_type = code.split("=", maxsplit=1)[0].rstrip().lstrip()
name = name_and_type.split(":", maxsplit=1)[0].rstrip().lstrip()
if not name:
raise ValueError(
"Could not find the variable to which the binding is assigned. Are you calling inject_field properly?"
)
code_string = optional_code_context[0].strip()
var_and_type = code_string.split("=")[0].rstrip().lstrip()
var = var_and_type.split(":")[0].rstrip().lstrip()
return var

# The third frame of the stack should be the class declaration (containing
# the "class" keyword). We use that line number as the key for looking up
# bindings, so double-check that that assumption holds.
# (If not, our inferred field name is probably wrong too!)
class_frame = stack[2]
if len(class_frame.code_context) < 1 or not class_frame.code_context[0].strip().startswith(
"class "
):
raise ValueError(
"Could not find line containing class declaration. Are you calling inject_field properly?"
)

def inject_field(binding=_T, **attr_field_kwargs) -> Any:
"""
Wrapper around attr.field which takes an argument to specify registry
bindings
"""
# add the binding to the key_binding_mapping to be retrieved in the call
# to inject_define
var_name = _get_init_kwarg_assignment()
_key_binding_mapping[_get_calling_function_key_from_depth(_DEPTH_OF_INJECT_FIELD_CALLER)][
var_name
] = binding
key = _BindingKey(filename=class_frame.filename, class_lineno=class_frame.lineno)
_key_binding_mapping[key][name] = binding
return field(**attr_field_kwargs)


Expand All @@ -157,21 +145,29 @@ def inject_define(
else:
attrs_kwargs = _get_compatible_attrs_define_kwargs()

# this variable represent how deep to look on the call stack
# to determine the name of the function that called this function.
# this is different depending on how a user used the decorator (with or without kwargs)
depth_of_caller = _DEPTH_OF_INJECT_DEFINE_CALLER

def inject_define_inner(cls: Type[_P]) -> Type[_P]:
# apply attr.define to generate static methods
cls = define(cls, **attrs_kwargs)

# get binding to apply to the class
file_of_class_being_bound = _get_calling_function_file(depth_of_caller)
key = _get_calling_function_key_from_filename_and_key(
cls.__name__, file_of_class_being_bound
# Identify the line containing the "class" keyword for cls: that is
# the line number that we used in the binding key for its fields.
class_lineno = None
(lines, start_lineno) = inspect.getsourcelines(cls)
for lineno, line in enumerate(lines, start_lineno):
if line.strip().startswith("class "):
class_lineno = lineno
if class_lineno is None:
raise ValueError(
"Could not find line containing class declaration. Are you calling inject_define properly?"
)

# get bindings to apply to the class
key = _BindingKey(
filename=inspect.getsourcefile(cls),
class_lineno=class_lineno,
)
bindings = _key_binding_mapping[key]
del _key_binding_mapping[key]

# apply attr.define to generate static methods
cls = define(cls, **attrs_kwargs)

# apply the bindings to the class
init_signature = inspect.signature(cls.__init__)
Expand All @@ -186,7 +182,6 @@ def inject_define_inner(cls: Type[_P]) -> Type[_P]:
return cls

if maybe_cls is None:
depth_of_caller = _DEPTH_OF_INJECT_DEFINE_CALLER_IF_NO_ARGS
return inject_define_inner

return inject_define_inner(maybe_cls)

0 comments on commit 0a36a46

Please sign in to comment.