From 3e554610fb58aaea85ed691d2fac4dc2ab9071a5 Mon Sep 17 00:00:00 2001 From: Timo Klein Date: Mon, 28 Oct 2024 09:24:38 +0100 Subject: [PATCH] fix lambda closure --- navix/components.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/navix/components.py b/navix/components.py index 6fa3826..80b32e5 100644 --- a/navix/components.py +++ b/navix/components.py @@ -21,7 +21,6 @@ from __future__ import annotations from typing import Tuple import inspect -import copy from jax import Array from flax import struct @@ -57,15 +56,13 @@ def __init_subclass__(cls, **kwargs): default = getattr(cls, f_name, dataclasses.MISSING) if isinstance(default, Array): # Create a field with a lambda as default factory to prevent mutable default values - f = dataclasses.field(default_factory=lambda: default) + # NOTE: The default value must be set to prevent the lambdas from only capturing the last default value + f = dataclasses.field(default_factory=lambda default=default: default) f.name = f_name f.type = f_type # Remove the field from the dataclass and replace by modified one - setattr(cls, f.name, f) + setattr(cls, f_name, f) - # BUG: locals() only stores a single value for defaults - if str(cls) == "": - import ipdb; ipdb.set_trace(context=21) struct.dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types