feat: prep for 0.1.0 release (#6)
* build: update dev requirements

* refactor: speed up shift2d operation

* feat: add stochastic depth to XMLP

* build: update setup.py to prepare for release

* fix: ensure drop_path_survival_rate is used everywhere, and error on unknown kwargs
mlw214 authored Dec 29, 2021
1 parent ca1d55f commit db7bb73
Showing 9 changed files with 338 additions and 87 deletions.
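The headline feature of this release is stochastic depth: during training each residual branch is dropped with some probability and rescaled when kept, which regularizes deep MLPs (Huang et al. 2016, the reference added to the README below). The XMLP internals are not among the files shown here, so the following is only a minimal Haiku sketch of the technique; the `drop_path` helper, `ResidualBlock` module, and `survival_rate` argument are illustrative names, not the library's API (the commit message above refers to a `drop_path_survival_rate` keyword).

```python
import haiku as hk
import jax
import jax.numpy as jnp


def drop_path(x: jnp.ndarray, survival_rate: float, is_training: bool) -> jnp.ndarray:
    """Stochastic depth on a residual branch (illustrative sketch, not x-mlps code)."""
    if not is_training or survival_rate >= 1.0:
        return x
    # One Bernoulli draw per call: keep the whole branch with probability `survival_rate`.
    keep = jax.random.bernoulli(hk.next_rng_key(), survival_rate)
    # Rescale kept branches so the expected output matches the deterministic eval path.
    return jnp.where(keep, x / survival_rate, jnp.zeros_like(x))


class ResidualBlock(hk.Module):
    """Toy residual block whose branch is skipped with probability 1 - survival_rate."""

    def __init__(self, dim: int, survival_rate: float, name: str = None):
        super().__init__(name=name)
        self.dim = dim
        self.survival_rate = survival_rate

    def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
        branch = hk.Linear(self.dim)(jax.nn.gelu(hk.Linear(self.dim)(x)))
        return x + drop_path(branch, self.survival_rate, is_training)
```

Because `hk.next_rng_key()` is only reached when `is_training=True`, the transformed model needs an RNG key during training but not at evaluation time; that is why the diffs below drop `hk.without_apply_rng` and start threading a key through every training step.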
21 changes: 14 additions & 7 deletions README.md
@@ -25,9 +25,7 @@ from einops import rearrange
 from x_mlps import XMLP, Affine, resmlp_block_factory
 
 def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10):
-    # NOTE: Operating directly on batched data is supported as well.
-    @hk.vmap
-    def model_fn(x: jnp.ndarray) -> jnp.ndarray:
+    def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
         # Reformat input image into a sequence of patches
         x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size)
         return XMLP(
@@ -37,16 +35,16 @@ def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10):
             block=resmlp_block_factory,
             normalization=lambda num_patches, dim, depth, **kwargs: Affine(dim, **kwargs),
             num_classes=num_classes,
-        )(x)
+        )(x, is_training=is_training)
 
-    return model_fn
+    # NOTE: Operating directly on batched data is supported as well.
+    return hk.vmap(model_fn, in_axes=(0, None))
 
 model = create_model(patch_size=4, dim=384, depth=12)
 model_fn = hk.transform(model)
-model_fn = hk.without_apply_rng(model_fn)
 
 rng = jax.random.PRNGKey(0)
-params = model_fn.init(rng, jnp.ones((1, 32, 32, 3)))
+params = model_fn.init(rng, jnp.ones((1, 32, 32, 3)), False)
 ```
 
 It's important to note the `XMLP` module _does not_ reformat input data to the form appropriate for whatever block you make use of (e.g., a sequence of patches).
@@ -155,3 +153,12 @@ See [LICENSE](LICENSE).
   volume={abs/2103.17239}
 }
 ```
+
+```bibtex
+@inproceedings{Huang2016DeepNW,
+  title={Deep Networks with Stochastic Depth},
+  author={Gao Huang and Yu Sun and Zhuang Liu and Daniel Sedra and Kilian Q. Weinberger},
+  booktitle={ECCV},
+  year={2016}
+}
+```
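The README change captures the new calling convention: `model_fn` takes an explicit `is_training` flag, batching is handled by `hk.vmap(model_fn, in_axes=(0, None))` so the flag is broadcast rather than mapped, and `hk.without_apply_rng` is gone because `apply` may now need a key for stochastic depth. A short sketch of the resulting usage, assuming `create_model` as defined in the README snippet above (the train/eval split mirrors the example scripts below):

```python
import haiku as hk
import jax
import jax.numpy as jnp

# `create_model` is assumed to be the function from the README snippet above,
# i.e. it returns hk.vmap(model_fn, in_axes=(0, None)).
model_fn = hk.transform(create_model(patch_size=4, dim=384, depth=12))

rng = jax.random.PRNGKey(0)
images = jnp.ones((1, 32, 32, 3))

# `is_training` is the second (broadcast) argument; pass False at initialization.
params = model_fn.init(rng, images, False)

# Training forward pass: supply a key so stochastic depth can sample which blocks to skip.
train_out = model_fn.apply(params, rng, images, True)

# Evaluation forward pass: deterministic, so the RNG slot can be None.
eval_out = model_fn.apply(params, None, images, False)
```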
49 changes: 33 additions & 16 deletions examples/gmlp.py
@@ -36,8 +36,7 @@ def collate_fn(batch):
 
 
 def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10):
-    @hk.vmap
-    def model_fn(x: jnp.ndarray) -> jnp.ndarray:
+    def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
         x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size)
         return XMLP(
             num_patches=x.shape[-2],
@@ -48,9 +47,9 @@ def model_fn(x: jnp.ndarray) -> jnp.ndarray:
             num_classes=num_classes,
             block_sublayers_ff_dim_hidden=FF_DIM_HIDDEN,
             block_sublayers_postnorm=lambda num_patches, dim, depth, **kwargs: LayerScale(dim, depth),
-        )(x)
+        )(x, is_training=is_training)
 
-    return model_fn
+    return hk.vmap(model_fn, in_axes=(0, None))
 
 
 def create_loss_fn(num_classes: int = 10, alpha: float = 0.1, reduction: str = "mean"):
@@ -67,8 +66,8 @@ def loss_fn(y_hat: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
 
 def create_step_fn(loss_fn, optimizer: optax.GradientTransformation):
     @jax.jit
-    def step_fn(params: hk.Params, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray):
-        loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
+    def step_fn(params: hk.Params, rng: jax.random.KeyArray, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray):
+        loss_value, grads = jax.value_and_grad(loss_fn)(params, rng, x, y)
         updates, opt_state = optimizer.update(grads, opt_state, params)
         params = optax.apply_updates(params, updates)
 
@@ -85,26 +84,39 @@ def fit(
     opt_state: optax.OptState,
     train_loader: DataLoader,
     val_loader: DataLoader,
+    rng: jax.random.KeyArray,
     num_epochs: int = 1,
 ):
-    def forward(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
-        y_hat = model_fn(params, x)
+    def forward(params: hk.Params, rng: jax.random.KeyArray, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+        y_hat = model_fn(params, rng, x, True)
         return loss_fn(y_hat, y)
 
+    @jax.jit
+    def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray:
+        return model_fn(params, None, x, False)
+
     step = create_step_fn(forward, optimizer)
 
+    # Ensure the model can even be run.
+    for i, (x, _) in enumerate(val_loader):
+        x = jnp.array(x)
+        _ = predict(params, x)
+        if i >= 2:
+            break
+
     for epoch in range(num_epochs):
         with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as t:
             for x, y in train_loader:
                 x, y = jnp.array(x), jnp.array(y)
-                params, opt_state, loss_value = step(params, opt_state, x, y)
+                rng, subkey = jax.random.split(rng)
+                params, opt_state, loss_value = step(params, subkey, opt_state, x, y)
                 t.set_postfix(loss=loss_value)
                 t.update()
         num_correct, num_samples = 0, 0
         losses = []
         for x, y in val_loader:
             x, y = jnp.array(x), jnp.array(y)
-            y_hat = model_fn(params, x)
+            y_hat = predict(params, x)
             loss_value = loss_fn(y_hat, y)
 
             losses.append(loss_value.item())
@@ -138,12 +150,9 @@ def main():
     # Create and initalize model
     model = create_model(patch_size=PATCH_SIZE, dim=DIM, depth=DEPTH)
     model_fn = hk.transform(model)
-    model_fn = hk.without_apply_rng(model_fn)
 
-    rng = jax.random.PRNGKey(0)
-    params = model_fn.init(rng, jnp.ones((1, 32, 32, 3)))
-
-    print(hk.experimental.tabulate(model_fn.apply)(jnp.ones((1, 32, 32, 3))))
+    key, subkey = jax.random.split(jax.random.PRNGKey(0))
+    params = model_fn.init(subkey, jnp.ones((1, 32, 32, 3)), False)
 
     # Create and initialize optimizer
     schedule = optax.warmup_cosine_decay_schedule(
@@ -159,7 +168,15 @@
     # Train!
     loss_fn = create_loss_fn()
     params, opt_state = fit(
-        jax.jit(model_fn.apply), loss_fn, optimizer, params, opt_state, train_loader, val_loader, num_epochs=NUM_EPOCHS
+        model_fn.apply,
+        loss_fn,
+        optimizer,
+        params,
+        opt_state,
+        train_loader,
+        val_loader,
+        rng=key,
+        num_epochs=NUM_EPOCHS,
     )
 
 
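The same RNG discipline is repeated verbatim in examples/mlpmixer.py and examples/resmlp.py below: split a fresh subkey for every optimisation step so the stochastic-depth draws differ from step to step, and pass `None` at evaluation time because `is_training=False` makes the forward pass deterministic. In isolation, with `step_fn` and `apply_fn` standing in for the jitted step and `model_fn.apply`, the pattern is roughly:

```python
import jax


def train_one_epoch(step_fn, params, opt_state, batches, rng):
    """Sketch of the per-step key handling used by the example scripts."""
    for x, y in batches:
        rng, subkey = jax.random.split(rng)  # never reuse a key across steps
        params, opt_state, loss = step_fn(params, subkey, opt_state, x, y)
    return params, opt_state, rng  # carry the key forward into the next epoch


def predict(apply_fn, params, x):
    # Stochastic depth is disabled when is_training=False, so no key is needed.
    return apply_fn(params, None, x, False)
```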
49 changes: 33 additions & 16 deletions examples/mlpmixer.py
@@ -36,8 +36,7 @@ def collate_fn(batch):
 
 
 def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10):
-    @hk.vmap
-    def model_fn(x: jnp.ndarray) -> jnp.ndarray:
+    def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
         x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size)
         return XMLP(
             num_patches=x.shape[-2],
@@ -47,9 +46,9 @@ def model_fn(x: jnp.ndarray) -> jnp.ndarray:
             normalization=layernorm_factory,
             num_classes=num_classes,
             block_sublayer1_ff_dim_hidden=PATCH_FF_DIM_HIDDEN,
-        )(x)
+        )(x, is_training=is_training)
 
-    return model_fn
+    return hk.vmap(model_fn, in_axes=(0, None))
 
 
 def create_loss_fn(num_classes: int = 10, alpha: float = 0.1, reduction: str = "mean"):
@@ -66,8 +65,8 @@ def loss_fn(y_hat: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
 
 def create_step_fn(loss_fn, optimizer: optax.GradientTransformation):
     @jax.jit
-    def step_fn(params: hk.Params, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray):
-        loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
+    def step_fn(params: hk.Params, rng: jax.random.KeyArray, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray):
+        loss_value, grads = jax.value_and_grad(loss_fn)(params, rng, x, y)
         updates, opt_state = optimizer.update(grads, opt_state, params)
         params = optax.apply_updates(params, updates)
 
@@ -84,26 +83,39 @@ def fit(
     opt_state: optax.OptState,
     train_loader: DataLoader,
     val_loader: DataLoader,
+    rng: jax.random.KeyArray,
     num_epochs: int = 1,
 ):
-    def forward(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
-        y_hat = model_fn(params, x)
+    def forward(params: hk.Params, rng: jax.random.KeyArray, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+        y_hat = model_fn(params, rng, x, True)
         return loss_fn(y_hat, y)
 
+    @jax.jit
+    def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray:
+        return model_fn(params, None, x, False)
+
     step = create_step_fn(forward, optimizer)
 
+    # Ensure the model can even be run.
+    for i, (x, _) in enumerate(val_loader):
+        x = jnp.array(x)
+        _ = predict(params, x)
+        if i >= 2:
+            break
+
     for epoch in range(num_epochs):
         with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as t:
             for x, y in train_loader:
                 x, y = jnp.array(x), jnp.array(y)
-                params, opt_state, loss_value = step(params, opt_state, x, y)
+                rng, subkey = jax.random.split(rng)
+                params, opt_state, loss_value = step(params, subkey, opt_state, x, y)
                 t.set_postfix(loss=loss_value)
                 t.update()
         num_correct, num_samples = 0, 0
         losses = []
         for x, y in val_loader:
             x, y = jnp.array(x), jnp.array(y)
-            y_hat = model_fn(params, x)
+            y_hat = predict(params, x)
             loss_value = loss_fn(y_hat, y)
 
             losses.append(loss_value.item())
@@ -137,12 +149,9 @@ def main():
     # Create and initalize model
    model = create_model(patch_size=PATCH_SIZE, dim=DIM, depth=DEPTH)
     model_fn = hk.transform(model)
-    model_fn = hk.without_apply_rng(model_fn)
 
-    rng = jax.random.PRNGKey(0)
-    params = model_fn.init(rng, jnp.ones((1, 32, 32, 3)))
-
-    print(hk.experimental.tabulate(model_fn.apply)(jnp.ones((1, 32, 32, 3))))
+    key, subkey = jax.random.split(jax.random.PRNGKey(0))
+    params = model_fn.init(subkey, jnp.ones((1, 32, 32, 3)), False)
 
     # Create and initialize optimizer
     schedule = optax.warmup_cosine_decay_schedule(
@@ -158,7 +167,15 @@
     # Train!
     loss_fn = create_loss_fn()
     params, opt_state = fit(
-        jax.jit(model_fn.apply), loss_fn, optimizer, params, opt_state, train_loader, val_loader, num_epochs=NUM_EPOCHS
+        model_fn.apply,
+        loss_fn,
+        optimizer,
+        params,
+        opt_state,
+        train_loader,
+        val_loader,
+        rng=key,
+        num_epochs=NUM_EPOCHS,
    )
 
 
42 changes: 26 additions & 16 deletions examples/resmlp.py
@@ -35,8 +35,7 @@ def collate_fn(batch):
 
 
 def create_model(patch_size: int, dim: int, depth: int, num_classes: int = 10):
-    @hk.vmap
-    def model_fn(x: jnp.ndarray) -> jnp.ndarray:
+    def model_fn(x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
         x = rearrange(x, "(h p1) (w p2) c -> (h w) (p1 p2 c)", p1=patch_size, p2=patch_size)
         return XMLP(
             num_patches=x.shape[-2],
@@ -45,9 +44,9 @@ def model_fn(x: jnp.ndarray) -> jnp.ndarray:
             block=resmlp_block_factory,
             normalization=lambda num_patches, dim, depth, **kwargs: Affine(dim, **kwargs),
             num_classes=num_classes,
-        )(x)
+        )(x, is_training=is_training)
 
-    return model_fn
+    return hk.vmap(model_fn, in_axes=(0, None))
 
 
 def create_loss_fn(num_classes: int = 10, alpha: float = 0.1, reduction: str = "mean"):
@@ -64,8 +63,8 @@ def loss_fn(y_hat: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
 
 def create_step_fn(loss_fn, optimizer: optax.GradientTransformation):
     @jax.jit
-    def step_fn(params: hk.Params, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray):
-        loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
+    def step_fn(params: hk.Params, rng: jax.random.KeyArray, opt_state: optax.OptState, x: jnp.ndarray, y: jnp.ndarray):
+        loss_value, grads = jax.value_and_grad(loss_fn)(params, rng, x, y)
         updates, opt_state = optimizer.update(grads, opt_state, params)
         params = optax.apply_updates(params, updates)
 
@@ -82,26 +81,32 @@ def fit(
     opt_state: optax.OptState,
     train_loader: DataLoader,
     val_loader: DataLoader,
+    rng: jax.random.KeyArray,
     num_epochs: int = 1,
 ):
-    def forward(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
-        y_hat = model_fn(params, x)
+    def forward(params: hk.Params, rng: jax.random.KeyArray, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+        y_hat = model_fn(params, rng, x, True)
         return loss_fn(y_hat, y)
 
+    @jax.jit
+    def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray:
+        return model_fn(params, None, x, False)
+
     step = create_step_fn(forward, optimizer)
 
     for epoch in range(num_epochs):
         with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as t:
             for x, y in train_loader:
                 x, y = jnp.array(x), jnp.array(y)
-                params, opt_state, loss_value = step(params, opt_state, x, y)
+                rng, subkey = jax.random.split(rng)
+                params, opt_state, loss_value = step(params, subkey, opt_state, x, y)
                 t.set_postfix(loss=loss_value)
                 t.update()
         num_correct, num_samples = 0, 0
         losses = []
         for x, y in val_loader:
             x, y = jnp.array(x), jnp.array(y)
-            y_hat = model_fn(params, x)
+            y_hat = predict(params, x)
             loss_value = loss_fn(y_hat, y)
 
             losses.append(loss_value.item())
@@ -135,12 +140,9 @@ def main():
     # Create and initalize model
     model = create_model(patch_size=PATCH_SIZE, dim=DIM, depth=DEPTH)
     model_fn = hk.transform(model)
-    model_fn = hk.without_apply_rng(model_fn)
 
-    rng = jax.random.PRNGKey(0)
-    params = model_fn.init(rng, jnp.ones((1, 32, 32, 3)))
-
-    print(hk.experimental.tabulate(model_fn.apply)(jnp.ones((1, 32, 32, 3))))
+    key, subkey = jax.random.split(jax.random.PRNGKey(0))
+    params = model_fn.init(subkey, jnp.ones((1, 32, 32, 3)), False)
 
     # Create and initialize optimizer
     schedule = optax.warmup_cosine_decay_schedule(
@@ -156,7 +158,15 @@
     # Train!
     loss_fn = create_loss_fn()
     params, opt_state = fit(
-        jax.jit(model_fn.apply), loss_fn, optimizer, params, opt_state, train_loader, val_loader, num_epochs=NUM_EPOCHS
+        model_fn.apply,
+        loss_fn,
+        optimizer,
+        params,
+        opt_state,
+        train_loader,
+        val_loader,
+        rng=key,
+        num_epochs=NUM_EPOCHS,
     )
 
 