Skip to content

Commit 7974b10

Browse files
author
Flax Authors
committed
Merge pull request #4580 from google:nnx-tabulate-issue
PiperOrigin-RevId: 731490914
2 parents 7f175f8 + 14fda9d commit 7974b10

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

flax/nnx/summary.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,12 @@ def _unflatten_to_simple_structure(xs: list[tuple[tuple[tp.Any, ...], tp.Any]]):
519519
else:
520520
cursor[key] = {}
521521
cursor = cursor[key]
522-
cursor[path[-1]] = value
522+
if isinstance(cursor, list):
523+
assert path[-1] == len(cursor)
524+
cursor.append(value)
525+
else:
526+
assert isinstance(cursor, dict)
527+
cursor[path[-1]] = value
523528
return result
524529

525530
def _as_yaml_str(value) -> str:

tests/nnx/summary_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,31 @@ def __call__(self, x):
7272
self.assertIn('5,790 (23.2 KB)', table_repr[34])
7373
self.assertIn('2 (12 B)', table_repr[34])
7474
self.assertIn('Total Parameters: 6,068 (24.3 KB)', table_repr[37])
75+
76+
def test_multiple_inputs_and_outputs(self):
77+
class CustomMLP(nnx.Module):
78+
def __init__(self):
79+
self.weight = nnx.Param(jnp.ones((4, 8)))
80+
self.bias = nnx.Param(jnp.ones(8))
81+
82+
def __call__(self, x, x2):
83+
y = x @ self.weight
84+
y += self.bias[None]
85+
y += x2
86+
return x, y, 2 * y
87+
88+
cmlp = CustomMLP()
89+
x = jnp.ones((1, 4))
90+
x2 = jnp.ones((1, 8))
91+
table_repr = nnx.tabulate(
92+
cmlp, x, x2, console_kwargs=CONSOLE_TEST_KWARGS
93+
).splitlines()
94+
95+
self.assertIn('CustomMLP Summary', table_repr[0])
96+
self.assertIn('float32[1,4]', table_repr[4])
97+
self.assertIn('float32[1,8]', table_repr[5])
98+
self.assertIn('float32[1,8]', table_repr[6])
99+
100+
101+
if __name__ == '__main__':
102+
absltest.main()

0 commit comments

Comments
 (0)