Skip to content

Commit 95d7170

Browse files
authored
Add FLAX (#1105)
Includes a smoke test. http://b/207406362
1 parent c2fc249 commit 95d7170

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

Dockerfile.tmpl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ RUN pip install pysal && \
137137
pip install seaborn python-dateutil dask python-igraph && \
138138
pip install pyyaml joblib husl geopy ml_metrics mne pyshp && \
139139
pip install pandas && \
140+
pip install flax && \
140141
# Install h2o from source.
141142
# Use `conda install -c h2oai h2o` once Python 3.7 version is released to conda.
142143
apt-get install -y default-jre-headless && \

tests/test_flax.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import unittest
2+
3+
import jax.numpy as jnp
4+
import numpy as np
5+
6+
from flax import linen as nn
7+
8+
9+
class TestFlax(unittest.TestCase):
10+
11+
def test_bla(self):
12+
x = jnp.full((1, 3, 3, 1), 2.)
13+
mul_reduce = lambda x, y: x * y
14+
y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
15+
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))

0 commit comments

Comments
 (0)