Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix lambda closure
Browse files Browse the repository at this point in the history
timoklein committed Oct 28, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 0d7a30b commit 3e55461
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions navix/components.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@
from __future__ import annotations
from typing import Tuple
import inspect
import copy

from jax import Array
from flax import struct
@@ -57,15 +56,13 @@ def __init_subclass__(cls, **kwargs):
default = getattr(cls, f_name, dataclasses.MISSING)
if isinstance(default, Array):
# Create a field with a lambda as default factory to prevent mutable default values
f = dataclasses.field(default_factory=lambda: default)
# NOTE: The default value must be set to prevent the lambdas from only capturing the last default value
f = dataclasses.field(default_factory=lambda default=default: default)
f.name = f_name
f.type = f_type
# Remove the field from the dataclass and replace by modified one
setattr(cls, f.name, f)
setattr(cls, f_name, f)

# BUG: locals() only stores a single value for defaults
if str(cls) == "<class 'navix.states.Event'>":
import ipdb; ipdb.set_trace(context=21)
struct.dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types


0 comments on commit 3e55461

Please sign in to comment.