From db7bb73a0154cdc55c35956e860882e384944d9f Mon Sep 17 00:00:00 2001 From: Miller Wilt Date: Tue, 28 Dec 2021 19:58:41 -0500 Subject: [PATCH] feat: prep for 0.1.0 release (#6) * build: update dev requirements * refactor: speed up shift2d operation * feat: add stochastic depth to XMLP * build: update setup.py to prepare for release * fix: ensure drop_path_survival_rate is used everywhere, and error on unknown kwargs --- README.md | 21 +++-- examples/gmlp.py | 49 +++++++---- examples/mlpmixer.py | 49 +++++++---- examples/resmlp.py | 42 +++++---- examples/s2mlp.py | 49 +++++++---- requirements/dev-requirements.in | 2 + requirements/dev-requirements.txt | 60 ++++++++++++- setup.py | 12 +++ src/x_mlps/_x_mlps.py | 141 +++++++++++++++++++++++++++--- 9 files changed, 338 insertions(+), 87 deletions(-) diff --git a/README.md b/README.md index 46038e2..b18388f 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,7 @@ from einops import rearrange from x_mlps import XMLP, Affine, resmlp_block_factory def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10): - # NOTE: Operating directly on batched data is supported as well. - @hk.vmap - def model_fn(x: jnp.ndarray) -> jnp.ndarray: + def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray: # Reformat input image into a sequence of patches x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size) return XMLP( @@ -37,16 +35,16 @@ def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10): block=resmlp_block_factory, normalization=lambda num_patches, dim, depth, **kwargs: Affine(dim, **kwargs), num_classes=num_classes, - )(x) + )(x, is_training=is_training) - return model_fn + # NOTE: Operating directly on batched data is supported as well. + return hk.vmap(model_fn, in_axes=(0, None)) model = create_model(patch_size=4, dim=384, depth=12) model_fn = hk.transform(model) -model_fn = hk.without_apply_rng(model_fn) rng = jax.random.PRNGKey(0) -params = model_fn.init(rng, jnp.ones((1, 32, 32, 3))) +params = model_fn.init(rng, jnp.ones((1, 32, 32, 3)), False) ``` It's important to note the `XMLP` module _does not_ reformat input data to the form appropriate for whatever block you make use of (e.g., a sequence of patches). @@ -155,3 +153,12 @@ See [LICENSE](LICENSE). volume={abs/2103.17239} } ``` + +```bibtex +@inproceedings{Huang2016DeepNW, + title={Deep Networks with Stochastic Depth}, + author={Gao Huang and Yu Sun and Zhuang Liu and Daniel Sedra and Kilian Q. 
Weinberger}, + booktitle={ECCV}, + year={2016} +} +``` diff --git a/examples/gmlp.py b/examples/gmlp.py index ee635fe..230b4e2 100644 --- a/examples/gmlp.py +++ b/examples/gmlp.py @@ -36,8 +36,7 @@ def collate_fn(batch): def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10): - @hk.vmap - def model_fn(x: jnp.ndarray) -> jnp.ndarray: + def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray: x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size) return XMLP( num_patches=x.shape[-2], @@ -48,9 +47,9 @@ def model_fn(x: jnp.ndarray) -> jnp.ndarray: num_classes=num_classes, block_sublayers_ff_dim_hidden=FF_DIM_HIDDEN, block_sublayers_postnorm=lambda num_patches, dim, depth, **kwargs: LayerScale(dim, depth), - )(x) + )(x, is_training=is_training) - return model_fn + return hk.vmap(model_fn, in_axes=(0, None)) def create_loss_fn(num_classes: int = 10, alpha: float = 0.1, reduction: str = "mean"): @@ -67,8 +66,8 @@ def loss_fn(y_hat: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def create_step_fn(loss_fn, optimizer: optax.GradientTransformation): @jax.jit - def step_fn(params: hk.Params, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray): - loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y) + def step_fn(params: hk.Params, rng: jax.random.KeyArray, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray): + loss_value, grads = jax.value_and_grad(loss_fn)(params, rng, x, y) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) @@ -85,26 +84,39 @@ def fit( opt_state: optax.OptState, train_loader: DataLoader, val_loader: DataLoader, + rng: jax.random.KeyArray, num_epochs: int = 1, ): - def forward(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - y_hat = model_fn(params, x) + def forward(params: hk.Params, rng: jax.random.KeyArray, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + y_hat = model_fn(params, rng, x, True) return loss_fn(y_hat, y) + @jax.jit + def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray: + return model_fn(params, None, x, False) + step = create_step_fn(forward, optimizer) + # Ensure the model can even be run. 
+ for i, (x, _) in enumerate(val_loader): + x = jnp.array(x) + _ = predict(params, x) + if i >= 2: + break + for epoch in range(num_epochs): with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as t: for x, y in train_loader: x, y = jnp.array(x), jnp.array(y) - params, opt_state, loss_value = step(params, opt_state, x, y) + rng, subkey = jax.random.split(rng) + params, opt_state, loss_value = step(params, subkey, opt_state, x, y) t.set_postfix(loss=loss_value) t.update() num_correct, num_samples = 0, 0 losses = [] for x, y in val_loader: x, y = jnp.array(x), jnp.array(y) - y_hat = model_fn(params, x) + y_hat = predict(params, x) loss_value = loss_fn(y_hat, y) losses.append(loss_value.item()) @@ -138,12 +150,9 @@ def main(): # Create and initalize model model = create_model(patch_size=PATCH_SIZE, dim=DIM, depth=DEPTH) model_fn = hk.transform(model) - model_fn = hk.without_apply_rng(model_fn) - - rng = jax.random.PRNGKey(0) - params = model_fn.init(rng, jnp.ones((1, 32, 32, 3))) - print(hk.experimental.tabulate(model_fn.apply)(jnp.ones((1, 32, 32, 3)))) + key, subkey = jax.random.split(jax.random.PRNGKey(0)) + params = model_fn.init(subkey, jnp.ones((1, 32, 32, 3)), False) # Create and initialize optimizer schedule = optax.warmup_cosine_decay_schedule( @@ -159,7 +168,15 @@ def main(): # Train! loss_fn = create_loss_fn() params, opt_state = fit( - jax.jit(model_fn.apply), loss_fn, optimizer, params, opt_state, train_loader, val_loader, num_epochs=NUM_EPOCHS + model_fn.apply, + loss_fn, + optimizer, + params, + opt_state, + train_loader, + val_loader, + rng=key, + num_epochs=NUM_EPOCHS, ) diff --git a/examples/mlpmixer.py b/examples/mlpmixer.py index a8622ba..c359a3c 100644 --- a/examples/mlpmixer.py +++ b/examples/mlpmixer.py @@ -36,8 +36,7 @@ def collate_fn(batch): def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10): - @hk.vmap - def model_fn(x: jnp.ndarray) -> jnp.ndarray: + def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray: x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size) return XMLP( num_patches=x.shape[-2], @@ -47,9 +46,9 @@ def model_fn(x: jnp.ndarray) -> jnp.ndarray: normalization=layernorm_factory, num_classes=num_classes, block_sublayer1_ff_dim_hidden=PATCH_FF_DIM_HIDDEN, - )(x) + )(x, is_training=is_training) - return model_fn + return hk.vmap(model_fn, in_axes=(0, None)) def create_loss_fn(num_classes: int = 10, alpha: float = 0.1, reduction: str = "mean"): @@ -66,8 +65,8 @@ def loss_fn(y_hat: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def create_step_fn(loss_fn, optimizer: optax.GradientTransformation): @jax.jit - def step_fn(params: hk.Params, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray): - loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y) + def step_fn(params: hk.Params, rng: jax.random.KeyArray, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray): + loss_value, grads = jax.value_and_grad(loss_fn)(params, rng, x, y) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) @@ -84,26 +83,39 @@ def fit( opt_state: optax.OptState, train_loader: DataLoader, val_loader: DataLoader, + rng: jax.random.KeyArray, num_epochs: int = 1, ): - def forward(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - y_hat = model_fn(params, x) + def forward(params: hk.Params, rng: jax.random.KeyArray, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + y_hat = model_fn(params, rng, x, True) 
return loss_fn(y_hat, y) + @jax.jit + def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray: + return model_fn(params, None, x, False) + step = create_step_fn(forward, optimizer) + # Ensure the model can even be run. + for i, (x, _) in enumerate(val_loader): + x = jnp.array(x) + _ = predict(params, x) + if i >= 2: + break + for epoch in range(num_epochs): with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as t: for x, y in train_loader: x, y = jnp.array(x), jnp.array(y) - params, opt_state, loss_value = step(params, opt_state, x, y) + rng, subkey = jax.random.split(rng) + params, opt_state, loss_value = step(params, subkey, opt_state, x, y) t.set_postfix(loss=loss_value) t.update() num_correct, num_samples = 0, 0 losses = [] for x, y in val_loader: x, y = jnp.array(x), jnp.array(y) - y_hat = model_fn(params, x) + y_hat = predict(params, x) loss_value = loss_fn(y_hat, y) losses.append(loss_value.item()) @@ -137,12 +149,9 @@ def main(): # Create and initalize model model = create_model(patch_size=PATCH_SIZE, dim=DIM, depth=DEPTH) model_fn = hk.transform(model) - model_fn = hk.without_apply_rng(model_fn) - - rng = jax.random.PRNGKey(0) - params = model_fn.init(rng, jnp.ones((1, 32, 32, 3))) - print(hk.experimental.tabulate(model_fn.apply)(jnp.ones((1, 32, 32, 3)))) + key, subkey = jax.random.split(jax.random.PRNGKey(0)) + params = model_fn.init(subkey, jnp.ones((1, 32, 32, 3)), False) # Create and initialize optimizer schedule = optax.warmup_cosine_decay_schedule( @@ -158,7 +167,15 @@ def main(): # Train! loss_fn = create_loss_fn() params, opt_state = fit( - jax.jit(model_fn.apply), loss_fn, optimizer, params, opt_state, train_loader, val_loader, num_epochs=NUM_EPOCHS + model_fn.apply, + loss_fn, + optimizer, + params, + opt_state, + train_loader, + val_loader, + rng=key, + num_epochs=NUM_EPOCHS, ) diff --git a/examples/resmlp.py b/examples/resmlp.py index f660209..be43415 100644 --- a/examples/resmlp.py +++ b/examples/resmlp.py @@ -35,8 +35,7 @@ def collate_fn(batch): def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10): - @hk.vmap - def model_fn(x: jnp.ndarray) -> jnp.ndarray: + def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray: x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size) return XMLP( num_patches=x.shape[-2], @@ -45,9 +44,9 @@ def model_fn(x: jnp.ndarray) -> jnp.ndarray: block=resmlp_block_factory, normalization=lambda num_patches, dim, depth, **kwargs: Affine(dim, **kwargs), num_classes=num_classes, - )(x) + )(x, is_training=is_training) - return model_fn + return hk.vmap(model_fn, in_axes=(0, None)) def create_loss_fn(num_classes: int = 10, alpha: float = 0.1, reduction: str = "mean"): @@ -64,8 +63,8 @@ def loss_fn(y_hat: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def create_step_fn(loss_fn, optimizer: optax.GradientTransformation): @jax.jit - def step_fn(params: hk.Params, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray): - loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y) + def step_fn(params: hk.Params, rng: jax.random.KeyArray, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray): + loss_value, grads = jax.value_and_grad(loss_fn)(params, rng, x, y) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) @@ -82,26 +81,32 @@ def fit( opt_state: optax.OptState, train_loader: DataLoader, val_loader: DataLoader, + rng: jax.random.KeyArray, num_epochs: int = 1, ): - def 
forward(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - y_hat = model_fn(params, x) + def forward(params: hk.Params, rng: jax.random.KeyArray, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + y_hat = model_fn(params, rng, x, True) return loss_fn(y_hat, y) + @jax.jit + def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray: + return model_fn(params, None, x, False) + step = create_step_fn(forward, optimizer) for epoch in range(num_epochs): with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as t: for x, y in train_loader: x, y = jnp.array(x), jnp.array(y) - params, opt_state, loss_value = step(params, opt_state, x, y) + rng, subkey = jax.random.split(rng) + params, opt_state, loss_value = step(params, subkey, opt_state, x, y) t.set_postfix(loss=loss_value) t.update() num_correct, num_samples = 0, 0 losses = [] for x, y in val_loader: x, y = jnp.array(x), jnp.array(y) - y_hat = model_fn(params, x) + y_hat = predict(params, x) loss_value = loss_fn(y_hat, y) losses.append(loss_value.item()) @@ -135,12 +140,9 @@ def main(): # Create and initalize model model = create_model(patch_size=PATCH_SIZE, dim=DIM, depth=DEPTH) model_fn = hk.transform(model) - model_fn = hk.without_apply_rng(model_fn) - - rng = jax.random.PRNGKey(0) - params = model_fn.init(rng, jnp.ones((1, 32, 32, 3))) - print(hk.experimental.tabulate(model_fn.apply)(jnp.ones((1, 32, 32, 3)))) + key, subkey = jax.random.split(jax.random.PRNGKey(0)) + params = model_fn.init(subkey, jnp.ones((1, 32, 32, 3)), False) # Create and initialize optimizer schedule = optax.warmup_cosine_decay_schedule( @@ -156,7 +158,15 @@ def main(): # Train! loss_fn = create_loss_fn() params, opt_state = fit( - jax.jit(model_fn.apply), loss_fn, optimizer, params, opt_state, train_loader, val_loader, num_epochs=NUM_EPOCHS + model_fn.apply, + loss_fn, + optimizer, + params, + opt_state, + train_loader, + val_loader, + rng=key, + num_epochs=NUM_EPOCHS, ) diff --git a/examples/s2mlp.py b/examples/s2mlp.py index 921f395..5523a9e 100644 --- a/examples/s2mlp.py +++ b/examples/s2mlp.py @@ -36,8 +36,7 @@ def collate_fn(batch): def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10): - @hk.vmap - def model_fn(x: jnp.ndarray) -> jnp.ndarray: + def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray: h, w, _ = x.shape x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size) return XMLP( @@ -48,9 +47,9 @@ def model_fn(x: jnp.ndarray) -> jnp.ndarray: normalization=layernorm_factory, num_classes=num_classes, block_sublayer1_ff_shift=create_shift2d_op(h // patch_size, w // patch_size), - )(x) + )(x, is_training=is_training) - return model_fn + return hk.vmap(model_fn, in_axes=(0, None)) def create_loss_fn(num_classes: int = 10, alpha: float = 0.1, reduction: str = "mean"): @@ -67,8 +66,8 @@ def loss_fn(y_hat: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def create_step_fn(loss_fn, optimizer: optax.GradientTransformation): @jax.jit - def step_fn(params: hk.Params, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray): - loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y) + def step_fn(params: hk.Params, rng: jax.random.KeyArray, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray): + loss_value, grads = jax.value_and_grad(loss_fn)(params, rng, x, y) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) @@ -85,26 +84,39 @@ def fit( opt_state: optax.OptState, train_loader: 
DataLoader, val_loader: DataLoader, + rng: jax.random.KeyArray, num_epochs: int = 1, ): - def forward(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - y_hat = model_fn(params, x) + def forward(params: hk.Params, rng: jax.random.KeyArray, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + y_hat = model_fn(params, rng, x, True) return loss_fn(y_hat, y) + @jax.jit + def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray: + return model_fn(params, None, x, False) + step = create_step_fn(forward, optimizer) + # Ensure the model can even be run. + for i, (x, _) in enumerate(val_loader): + x = jnp.array(x) + _ = predict(params, x) + if i >= 2: + break + for epoch in range(num_epochs): with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as t: for x, y in train_loader: x, y = jnp.array(x), jnp.array(y) - params, opt_state, loss_value = step(params, opt_state, x, y) + rng, subkey = jax.random.split(rng) + params, opt_state, loss_value = step(params, subkey, opt_state, x, y) t.set_postfix(loss=loss_value) t.update() num_correct, num_samples = 0, 0 losses = [] for x, y in val_loader: x, y = jnp.array(x), jnp.array(y) - y_hat = model_fn(params, x) + y_hat = predict(params, x) loss_value = loss_fn(y_hat, y) losses.append(loss_value.item()) @@ -138,12 +150,9 @@ def main(): # Create and initalize model model = create_model(patch_size=PATCH_SIZE, dim=DIM, depth=DEPTH) model_fn = hk.transform(model) - model_fn = hk.without_apply_rng(model_fn) - - rng = jax.random.PRNGKey(0) - params = model_fn.init(rng, jnp.ones((1, 32, 32, 3))) - print(hk.experimental.tabulate(model_fn.apply)(jnp.ones((1, 32, 32, 3)))) + key, subkey = jax.random.split(jax.random.PRNGKey(0)) + params = model_fn.init(subkey, jnp.ones((1, 32, 32, 3)), False) # Create and initialize optimizer schedule = optax.warmup_cosine_decay_schedule( @@ -159,7 +168,15 @@ def main(): # Train! 
loss_fn = create_loss_fn() params, opt_state = fit( - jax.jit(model_fn.apply), loss_fn, optimizer, params, opt_state, train_loader, val_loader, num_epochs=NUM_EPOCHS + model_fn.apply, + loss_fn, + optimizer, + params, + opt_state, + train_loader, + val_loader, + rng=key, + num_epochs=NUM_EPOCHS, ) diff --git a/requirements/dev-requirements.in b/requirements/dev-requirements.in index 2fe99ad..b924b23 100644 --- a/requirements/dev-requirements.in +++ b/requirements/dev-requirements.in @@ -1,6 +1,7 @@ -c requirements.txt -c jax-gpu.txt bandit +build black flake8 flake8-bugbear @@ -14,3 +15,4 @@ setuptools-scm torch torchvision tqdm +twine diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index d1cbb2d..16c89e0 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -26,8 +26,14 @@ bandit==1.7.1 # via -r dev-requirements.in black==21.12b0 # via -r dev-requirements.in +bleach==4.1.0 + # via readme-renderer +build==0.7.0 + # via -r dev-requirements.in certifi==2021.10.8 # via requests +cffi==1.15.0 + # via cryptography cfgv==3.3.1 # via pre-commit charset-normalizer==2.0.9 @@ -38,12 +44,18 @@ click==8.0.3 # via # black # safety +colorama==0.4.4 + # via twine +cryptography==36.0.1 + # via secretstorage decorator==5.1.0 # via ipython distlib==0.3.4 # via virtualenv dm-tree==0.1.6 # via chex +docutils==0.18.1 + # via readme-renderer dparse==0.5.1 # via safety filelock==3.4.0 @@ -66,6 +78,10 @@ identify==2.4.0 # via pre-commit idna==3.3 # via requests +importlib-metadata==4.10.0 + # via + # keyring + # twine iniconfig==1.1.1 # via pytest ipython==7.30.1 @@ -85,6 +101,12 @@ jaxlib==0.1.75+cuda11.cudnn82 # optax jedi==0.18.1 # via ipython +jeepney==0.7.1 + # via + # keyring + # secretstorage +keyring==23.4.0 + # via twine matplotlib-inline==0.1.3 # via ipython mccabe==0.6.1 @@ -112,6 +134,8 @@ optax @ git+git://github.com/deepmind/optax.git # via -r dev-requirements.in packaging==21.3 # via + # bleach + # build # dparse # pytest # safety @@ -122,12 +146,16 @@ pathspec==0.9.0 # via black pbr==5.8.0 # via stevedore +pep517==0.12.0 + # via build pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython pillow==8.4.0 # via torchvision +pkginfo==1.8.2 + # via twine platformdirs==2.4.0 # via # black @@ -144,10 +172,14 @@ py==1.11.0 # via pytest pycodestyle==2.8.0 # via flake8 +pycparser==2.21 + # via cffi pyflakes==2.4.0 # via flake8 pygments==2.10.0 - # via ipython + # via + # ipython + # readme-renderer pyparsing==3.0.6 # via packaging pytest==6.2.5 @@ -157,8 +189,17 @@ pyyaml==6.0 # bandit # dparse # pre-commit +readme-renderer==32.0 + # via twine requests==2.26.0 - # via safety + # via + # requests-toolbelt + # safety + # twine +requests-toolbelt==0.9.1 + # via twine +rfc3986==1.5.0 + # via twine safety==1.10.3 # via -r dev-requirements.in scipy==1.7.3 @@ -166,6 +207,8 @@ scipy==1.7.3 # -c jax-gpu.txt # jax # jaxlib +secretstorage==3.3.1 + # via keyring setuptools-scm==6.3.2 # via -r dev-requirements.in six==1.16.0 @@ -173,6 +216,7 @@ six==1.16.0 # -c jax-gpu.txt # -c requirements.txt # absl-py + # bleach # dm-tree # virtualenv smmap==5.0.0 @@ -187,6 +231,8 @@ toml==0.10.2 tomli==1.2.3 # via # black + # build + # pep517 # setuptools-scm toolz==0.11.2 # via chex @@ -197,11 +243,15 @@ torch==1.10.1 torchvision==0.11.2 # via -r dev-requirements.in tqdm==4.62.3 - # via -r dev-requirements.in + # via + # -r dev-requirements.in + # twine traitlets==5.1.1 # via # ipython # matplotlib-inline +twine==3.7.1 + # via -r dev-requirements.in 
typing-extensions==4.0.1 # via # -c jax-gpu.txt @@ -216,6 +266,10 @@ virtualenv==20.10.0 # via pre-commit wcwidth==0.2.5 # via prompt-toolkit +webencodings==0.5.1 + # via bleach +zipp==3.6.0 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/setup.py b/setup.py index a3921c6..8d8a76a 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ REQUIREMENTS = HERE.joinpath("requirements", "requirements.in").read_text().split() setup( + # Metadata name="x-mlps", author="Miller Wilt", author_email="miller@pyriteai.com", @@ -14,6 +15,17 @@ description="Configurable MLPs built on JAX and Haiku", long_description=README, long_description_content_type="text/markdown", + url="https://github.com/PyriteAI/x-mlps", + keywords=["artificial intelligence", "machine learning", "jax", "haiku", "multilayer perceptron", "mlp"], + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Typing :: Typed", + ], + # Options packages=find_packages("src"), package_dir={"": "src"}, install_requires=REQUIREMENTS, diff --git a/src/x_mlps/_x_mlps.py b/src/x_mlps/_x_mlps.py index 374e319..04220ac 100644 --- a/src/x_mlps/_x_mlps.py +++ b/src/x_mlps/_x_mlps.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional, Protocol, Sequence +from typing import Any, Callable, Optional, Protocol, Sequence, Union import haiku as hk import jax @@ -41,18 +41,45 @@ def create_shift2d_op(height: int, width: int, amount: int = 1) -> Callable[[jnp """ def shift2d(x: jnp.ndarray) -> jnp.ndarray: - c = x.shape[-1] x = rearrange(x, "... (h w) c -> ... h w c", h=height, w=width) - x = x.at[amount:, :, : c // 4].set(x[:-amount, :, : c // 4]) - x = x.at[:-amount, :, c // 4 : c // 2].set(x[amount:, :, c // 4 : c // 2]) - x = x.at[:, amount:, c // 2 : 3 * c // 4].set(x[:, :-amount, c // 2 : 3 * c // 4]) - x = x.at[:, :-amount, 3 * c // 4 : c].set(x[:, amount:, 3 * c // 4 : c]) + x1, x2, x3, x4 = jnp.split(x, 4, axis=-1) + x1 = x1.at[amount:].set(x1[:-amount]) + x2 = x2.at[:-amount].set(x2[amount:]) + x3 = x3.at[:, amount:].set(x3[:, :-amount]) + x4 = x4.at[:, :-amount].set(x4[:, amount:]) + x = jnp.concatenate([x1, x2, x3, x4], axis=-1) x = rearrange(x, "... h w c -> ... (h w) c") return x return shift2d +class SampleDropout(hk.Module): + """Randomly drop the input with a given probability. + + This is equivalent to Stochastic Depth when applied to the output of a network path¹. + + Args: + rate (float): Probability of dropping an element. + name (str, optional): Name of the module. + + References: + 1. Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382). + """ + + def __init__(self, rate: float, name: Optional[str] = None): + super().__init__(name=name) + + self.rate = rate + + def __call__(self, x: jnp.ndarray, *, is_training: bool) -> jnp.ndarray: + if is_training: + return hk.cond( + jax.random.bernoulli(hk.next_rng_key(), 1 - self.rate), lambda x: x, lambda x: jnp.zeros_like(x), x + ) + return x + + class Affine(hk.Module): """Affine transform layer as described in ResMLP¹. 
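For context on the stochastic depth change introduced above: the new `SampleDropout` module keeps its input with probability `1 - rate` during training, zeroes it otherwise, and is the identity at evaluation time. Below is a minimal sketch of how it behaves on its own; the toy `forward` function, the shapes, and the 0.1 rate are illustrative only and not part of this patch.

```python
import haiku as hk
import jax
import jax.numpy as jnp

from x_mlps import SampleDropout  # exported by this patch


def forward(x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
    # A toy residual branch: keep it with probability 0.9 during training, drop it otherwise.
    branch = hk.Linear(x.shape[-1])(x)
    return SampleDropout(rate=0.1)(branch, is_training=is_training) + x


fn = hk.transform(forward)
x = jnp.ones((4, 16))
rng = jax.random.PRNGKey(0)
params = fn.init(rng, x, True)

y_train = fn.apply(params, rng, x, True)   # the branch is randomly zeroed on some calls
y_eval = fn.apply(params, None, x, False)  # deterministic: SampleDropout is a no-op
```

Because the drop decision comes from `hk.next_rng_key()`, training-time `apply` calls need an RNG key, which is why the examples in this patch stop using `hk.without_apply_rng`.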
@@ -143,6 +170,9 @@ def __init__(
     ):
         super().__init__(name=name)
 
+        if kwargs:
+            raise KeyError(f"unknown keyword arguments: {list(kwargs.keys())}")
+
         if norm is None:
             norm = layernorm_factory
         if activation is None:
@@ -201,6 +231,9 @@ def __init__(
     ):
         super().__init__(name=name)
 
+        if kwargs:
+            raise KeyError(f"unknown keyword arguments: {list(kwargs.keys())}")
+
         self.num_patches = num_patches
         self.dim = dim
         self.depth = depth
@@ -235,6 +268,9 @@ class ResMLPXPatchFeedForward(hk.Module):
     def __init__(self, num_patches: int, dim: int, depth: int, name: Optional[str] = None, **kwargs: Any):
         super().__init__(name=name)
 
+        if kwargs:
+            raise KeyError(f"unknown keyword arguments: {list(kwargs.keys())}")
+
         self.num_patches = num_patches
         self.dim = dim
         self.depth = depth
@@ -279,6 +315,8 @@ def __init__(
         super().__init__(name=name)
 
         sgu_kwargs, kwargs = group_by_prefix_and_trim("sgu_", kwargs)
+        if kwargs:
+            raise KeyError(f"unknown keyword arguments: {list(kwargs.keys())}")
 
         if sgu is None:
             sgu = sgu_factory
@@ -331,6 +369,9 @@ def __init__(
     ):
         super().__init__(name=name)
 
+        if kwargs:
+            raise KeyError(f"unknown keyword arguments: {list(kwargs.keys())}")
+
         self.num_patches = num_patches
         self.dim = dim
         self.depth = depth
@@ -358,6 +399,8 @@ class XSublayer(hk.Module):
         prenorm (XModuleFactory, optional): Pre-normalization layer factory function. Defaults to `None`.
         postnorm (XModuleFactory, optional): Post-normalization layer factory function. Defaults to `None`.
         residual (bool): Whether to add a residual/skip connection. Defaults to `True`.
+        drop_path_survival_rate (float): Probability of the core computation being active (not dropped). Only applicable
+            if `residual` is `True`. Defaults to 1.0.
         name (str, optional): The name of the module. Defaults to None.
         **kwargs: All arguments starting with "ff_" are passed to the feedforward layer factory function.
             All arguments starting with "prenorm_" are passed to the pre-normalization layer factory function.
@@ -373,6 +416,7 @@ def __init__(
         self,
         num_patches: int,
         dim: int,
         depth: int,
         prenorm: Optional[XModuleFactory] = None,
         postnorm: Optional[XModuleFactory] = None,
         residual: bool = True,
+        drop_path_survival_rate: float = 1.0,
         name: Optional[str] = None,
         **kwargs: Any,
     ):
@@ -381,6 +425,8 @@ def __init__(
         ff_kwargs, kwargs = group_by_prefix_and_trim("ff_", kwargs)
         prenorm_kwargs, kwargs = group_by_prefix_and_trim("prenorm_", kwargs)
         postnorm_kwargs, kwargs = group_by_prefix_and_trim("postnorm_", kwargs)
+        if kwargs:
+            raise KeyError(f"unknown keyword arguments: {list(kwargs.keys())}")
 
         self.num_patches = num_patches
         self.dim = dim
@@ -389,11 +435,21 @@ def __init__(
         self.prenorm = prenorm
         self.postnorm = postnorm
         self.residual = residual
+        self.drop_path_survival_rate = drop_path_survival_rate
         self.ff_kwargs = ff_kwargs
         self.prenorm_kwargs = prenorm_kwargs
         self.postnorm_kwargs = postnorm_kwargs
 
-    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, inputs: jnp.ndarray, *, is_training: bool) -> jnp.ndarray:
+        """Propagate inputs through the sublayer.
+
+        Args:
+            inputs (jnp.ndarray): Inputs to the sublayer.
+            is_training (bool): If `True`, enable training-specific features (e.g., dropout). Keyword argument.
+
+        Returns:
+            jnp.ndarray: Outputs of the sublayer.
+ """ x = inputs if self.prenorm is not None: x = self.prenorm(self.num_patches, self.dim, self.depth, **self.prenorm_kwargs)(x) @@ -401,7 +457,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: if self.postnorm is not None: x = self.postnorm(self.num_patches, self.dim, self.depth, **self.postnorm_kwargs)(x) if self.residual: - x += inputs + x = SampleDropout(1 - self.drop_path_survival_rate)(x, is_training=is_training) + inputs return x @@ -423,6 +479,8 @@ class XBlock(hk.Module): sublayers (Sequence[XSublayerFactory]): Sublayer factory functions. Created sublayers will be stacked in the order of their respective factory function in the sequence. residual (bool): Whether to add a residual/skip connection. Defaults to `False`. + drop_path_survival_rate (float): Probability of the core computation being active (not dropped). Passed directly + to sublayers. This will also be applied at the block level if residual is `True`. Defaults to 1.0. name (str, optional): The name of the module. Defaults to None. **kwargs: All arguments starting with "sublayers_" are passed to all sublayers. All arguments starting with "sublayer{i}_" are passed to the i-th sublayer. @@ -435,6 +493,7 @@ def __init__( depth: int, sublayers: Sequence[XModuleFactory], residual: bool = False, + drop_path_survival_rate: float = 1.0, name: Optional[str] = None, **kwargs: Any, ): @@ -445,23 +504,41 @@ def __init__( for i in range(len(sublayers)): sublayer_kwargs, kwargs = group_by_prefix_and_trim(f"sublayer{i + 1}_", kwargs) sublayers_kwargs.append(sublayer_kwargs) + if kwargs: + raise KeyError(f"unknown keyword arguments: {list(kwargs.keys())}") self.num_patches = num_patches self.dim = dim self.depth = depth self.sublayers = tuple(sublayers) self.residual = residual + self.drop_path_survival_rate = drop_path_survival_rate self.sublayer_common_kwargs = sublayer_common_kwargs self.sublayers_kwargs = sublayers_kwargs - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + def __call__(self, inputs: jnp.ndarray, *, is_training: bool) -> jnp.ndarray: + """Propogate inputs through the block. + + Args: + inputs (jnp.ndarray): Inputs to the block. + is_training (bool): If `True`, enable training specific features (e.g., dropout). Keyword argument. + + Returns: + jnp.ndarray: Outputs of the block. + """ x = inputs for i, sublayer in enumerate(self.sublayers): sublayer_kwargs = self.sublayer_common_kwargs.copy() sublayer_kwargs.update(self.sublayers_kwargs[i]) - x = sublayer(self.num_patches, self.dim, self.depth, **sublayer_kwargs)(x) + x = sublayer( + self.num_patches, + self.dim, + self.depth, + drop_path_survival_rate=self.drop_path_survival_rate, + **sublayer_kwargs, + )(x, is_training=is_training) if self.residual: - x += inputs + x = SampleDropout(1 - self.drop_path_survival_rate)(x, is_training=is_training) + inputs return x @@ -472,6 +549,10 @@ class XMLP(hk.Module): assumes the input has been formatted appropriately (e.g., a sequence of patches). Before data is processed by the stack of `XBlock` modules, it is first projected to the specified dimension `dim` via a linear layer. + This network can optionally be configured with Stochastic Depth¹, a form of regularization. If enabled, the depth + of the network will be dynamically adjusted during training, with sections of the network being randomly dropped. + The likelihood of dropping a layer can either fixed, or dependent on the depth of the network. 
+
     Optionally, the network can be configured to have a classification layer at the end by setting `num_classes` to a
     non-zero value. In this case, the resulting sequence from stack of `XBlock` modules will be averaged over the
     sequence dimension before being fed to the classification layer.
@@ -489,9 +570,15 @@ class XMLP(hk.Module):
         block (XBlockFactory): Block factory function.
         normalization (XModuleFactory, optional): Normalization module factory function. Occurs after the stack of
             `XBlock` modules. Useful for pre-normalization architectures. Defaults to None.
+        stochastic_depth (Union[bool, float], optional): Whether to use stochastic depth. If `True`, the survival rate
+            of each block follows the linear decay function 1 - 0.5 * (i / depth) for 1 <= i <= depth. If `False`, the
+            survival rate is 1.0. If a float, the survival rate is set to this value. Defaults to False.
         num_classes (int, optional): Number of classes in the classification layer. Defaults to None.
         name (str, optional): The name of the module. Defaults to None.
         **kwargs: All arguments starting with "block_" are passed to all blocks.
+
+    References:
+        1. Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382).
     """
 
     def __init__(
         self,
@@ -501,6 +588,7 @@ def __init__(
         depth: int,
         block: XModuleFactory,
         normalization: Optional[XModuleFactory] = None,
+        stochastic_depth: Union[float, bool] = False,
         num_classes: Optional[int] = None,
         name: Optional[str] = None,
         **kwargs: Any,
@@ -508,19 +596,45 @@ def __init__(
     ):
         super().__init__(name=name)
 
         block_kwargs, kwargs = group_by_prefix_and_trim("block_", kwargs)
+        if kwargs:
+            raise KeyError(f"unknown keyword arguments: {list(kwargs.keys())}")
+
+        if isinstance(stochastic_depth, bool) and stochastic_depth:
+            # This ensures that the first block can be dropped as well.
+            drop_path_survival_rates = jnp.linspace(1.0, 0.5, num=depth + 1)[1:]
+        elif isinstance(stochastic_depth, float):
+            drop_path_survival_rates = jnp.full(depth, stochastic_depth)
+        else:
+            drop_path_survival_rates = jnp.ones(depth)
 
         self.num_patches = num_patches
         self.dim = dim
         self.depth = depth
         self.block = block
         self.normalization = normalization
+        self.drop_path_survival_rates = drop_path_survival_rates
         self.num_classes = num_classes
         self.block_kwargs = block_kwargs
 
-    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, inputs: jnp.ndarray, *, is_training: bool) -> jnp.ndarray:
+        """Propagate inputs through the network.
+
+        Args:
+            inputs (jnp.ndarray): Inputs to the network.
+            is_training (bool): If `True`, enable training-specific features (e.g., dropout). Keyword argument.
+
+        Returns:
+            jnp.ndarray: Outputs of the network.
+        """
         x = hk.Linear(self.dim, name="proj_in")(inputs)
         for i in range(self.depth):
-            x = self.block(self.num_patches, self.dim, i + 1, **self.block_kwargs)(x)
+            x = self.block(
+                self.num_patches,
+                self.dim,
+                i + 1,
+                drop_path_survival_rate=self.drop_path_survival_rates[i],
+                **self.block_kwargs,
+            )(x, is_training=is_training)
         if self.normalization is not None:
            x = self.normalization(self.num_patches, self.dim, self.depth + 1)(x)
         if self.num_classes is not None:
@@ -853,6 +967,7 @@ def s2mlp_block_factory(num_patches: int, dim: int, depth: int, name: Optional[s
     "LayerScale",
     "MLPMixerXPatchFeedForward",
     "ResMLPXPatchFeedForward",
+    "SampleDropout",
     "SpatialGatingUnit",
     "XBlock",
    "XChannelFeedForward",
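Taken together, the API changes in this release can be exercised roughly as follows. The sketch mirrors the updated README and examples (ResMLP block, 4×4 patches on 32×32×3 inputs, `is_training` threaded through `__call__`, RNG passed to `apply` during training); the `stochastic_depth=True` flag and the batch size are illustrative additions rather than part of the README itself.

```python
import haiku as hk
import jax
import jax.numpy as jnp
from einops import rearrange

from x_mlps import XMLP, Affine, resmlp_block_factory


def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10):
    def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
        # Reformat a single image into a sequence of patches.
        x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size)
        return XMLP(
            num_patches=x.shape[-2],
            dim=dim,
            depth=depth,
            block=resmlp_block_factory,
            normalization=lambda num_patches, dim, depth, **kwargs: Affine(dim, **kwargs),
            stochastic_depth=True,  # per-block survival rates decay linearly from ~0.96 down to 0.5
            num_classes=num_classes,
        )(x, is_training=is_training)

    # Map over the batch dimension of the inputs only; `is_training` is broadcast.
    return hk.vmap(model_fn, in_axes=(0, None))


model_fn = hk.transform(create_model(patch_size=4, dim=384, depth=12))

rng = jax.random.PRNGKey(0)
params = model_fn.init(rng, jnp.ones((1, 32, 32, 3)), False)

# Training: pass a fresh key so SampleDropout can decide which blocks to drop.
rng, subkey = jax.random.split(rng)
logits = model_fn.apply(params, subkey, jnp.ones((8, 32, 32, 3)), True)

# Evaluation: stochastic depth is disabled, so no RNG is needed.
logits = model_fn.apply(params, None, jnp.ones((8, 32, 32, 3)), False)
```

With `stochastic_depth=True` and depth 12, the survival rates follow 1 - 0.5 * (i / depth), so the first block is kept about 96% of the time and the last block only half the time during training; at evaluation every block is always executed.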