Skip to content

Commit

Permalink
Extremely large context size
Browse files Browse the repository at this point in the history
  • Loading branch information
rdes committed Jan 21, 2024
1 parent 69acfd4 commit 89e5766
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 51 deletions.
44 changes: 22 additions & 22 deletions examples/lotka-voltera/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,24 @@ def step(state, t):
# return jnp.concatenate([y0[jnp.newaxis, :], ys], axis=0)


## Training environments
# environments = [
# {"alpha": 0.5, "beta": 0.5, "gamma": 0.5, "delta": 0.5},
# {"alpha": 0.5, "beta": 0.75, "gamma": 0.5, "delta": 0.5},
# {"alpha": 0.5, "beta": 1.0, "gamma": 0.5, "delta": 0.5},
# {"alpha": 0.5, "beta": 0.5, "gamma": 0.5, "delta": 0.75},
# {"alpha": 0.5, "beta": 0.75, "gamma": 0.5, "delta": 0.75},
# {"alpha": 0.5, "beta": 1.0, "gamma": 0.5, "delta": 0.75},
# {"alpha": 0.5, "beta": 0.5, "gamma": 0.5, "delta": 1.0},
# {"alpha": 0.5, "beta": 0.75, "gamma": 0.5, "delta": 1.0},
# {"alpha": 0.5, "beta": 1.0, "gamma": 0.5, "delta": 1.0},
# ]

## Lots of data environment
environments = []
for beta in np.linspace(0.5, 1.5, 11):
new_env = {"alpha": 0.5, "beta": beta, "gamma": 0.5, "delta": 0.5}
environments.append(new_env)
# Training environments
environments = [
{"alpha": 0.5, "beta": 0.5, "gamma": 0.5, "delta": 0.5},
{"alpha": 0.5, "beta": 0.75, "gamma": 0.5, "delta": 0.5},
{"alpha": 0.5, "beta": 1.0, "gamma": 0.5, "delta": 0.5},
{"alpha": 0.5, "beta": 0.5, "gamma": 0.5, "delta": 0.75},
{"alpha": 0.5, "beta": 0.75, "gamma": 0.5, "delta": 0.75},
{"alpha": 0.5, "beta": 1.0, "gamma": 0.5, "delta": 0.75},
{"alpha": 0.5, "beta": 0.5, "gamma": 0.5, "delta": 1.0},
{"alpha": 0.5, "beta": 0.75, "gamma": 0.5, "delta": 1.0},
{"alpha": 0.5, "beta": 1.0, "gamma": 0.5, "delta": 1.0},
]

# ## Lots of data environment
# environments = []
# for beta in np.linspace(0.5, 1.5, 11):
# new_env = {"alpha": 0.5, "beta": beta, "gamma": 0.5, "delta": 0.5}
# environments.append(new_env)

# ## Adaptation environments
# environments = [
Expand All @@ -88,11 +88,11 @@ def step(state, t):
# {"alpha": 0.5, "beta": 1.125, "gamma": 0.5, "delta": 1.125},
# ]

# n_traj_per_env = 4 ## training
n_traj_per_env = 32 ## testing
n_traj_per_env = 4 ## training
# n_traj_per_env = 32 ## testing
# n_traj_per_env = 1 ## adaptation

n_steps_per_traj = int(10/0.05)+1 ## from coda
n_steps_per_traj = int(10/0.5)+1 ## from coda
# n_steps_per_traj = 201

data = np.zeros((len(environments), n_traj_per_env, n_steps_per_traj, 2))
Expand Down Expand Up @@ -146,7 +146,7 @@ def animate(i):
plt.show()

# Save t_eval and the solution to an npz file
np.savez('tmp/test_data.npz', t=t_eval, X=data)
np.savez('tmp/train_data.npz', t=t_eval, X=data)

## Save the movie to a small mp4 file
# ani.save('tmp/lotka_volterra.mp4', fps=30, extra_args=['-vcodec', 'libx264'])
Expand Down
60 changes: 34 additions & 26 deletions examples/lotka-voltera/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

## Hyperparams
SEED = 3
context_size = 2
nb_epochs = 5000
context_size = 8000
nb_epochs = 1000



Expand Down Expand Up @@ -95,17 +95,22 @@ class Augmentation(eqx.Module):
def __init__(self, data_size, width_size, depth, context_size, key=None):
keys = generate_new_keys(key, num=12)
self.layers_data = [eqx.nn.Linear(data_size, width_size, key=keys[0]), activation,
# eqx.nn.Linear(width_size, width_size, key=keys[10]), activation,
eqx.nn.Linear(width_size, width_size, key=keys[10]), activation,
eqx.nn.Linear(width_size, width_size, key=keys[1]), activation,
eqx.nn.Linear(width_size, data_size, key=keys[2])]

self.layers_context = [eqx.nn.Linear(context_size, width_size, key=keys[3]), activation,
# eqx.nn.Linear(width_size, width_size, key=keys[11]), activation,
eqx.nn.Linear(width_size, width_size, key=keys[4]), activation,
# self.layers_context = [eqx.nn.Linear(context_size, width_size, key=keys[3]), activation,
# eqx.nn.Linear(width_size, width_size, key=keys[11]), activation,
# eqx.nn.Linear(width_size, width_size, key=keys[4]), activation,
# eqx.nn.Linear(width_size, data_size, key=keys[5])]

self.layers_context = [eqx.nn.Linear(context_size, context_size//10, key=keys[3]), activation,
eqx.nn.Linear(context_size//10, width_size*4, key=keys[11]), activation,
eqx.nn.Linear(width_size*4, width_size, key=keys[4]), activation,
eqx.nn.Linear(width_size, data_size, key=keys[5])]

self.layers_shared = [eqx.nn.Linear(data_size*2, width_size, key=keys[6]), activation,
# eqx.nn.Linear(width_size, width_size, key=keys[7]), activation,
eqx.nn.Linear(width_size, width_size, key=keys[7]), activation,
eqx.nn.Linear(width_size, width_size, key=keys[8]), activation,
eqx.nn.Linear(width_size, data_size, key=keys[9])]

Expand All @@ -129,7 +134,7 @@ def __call__(self, t, x, ctx):

# physics = Physics(key=SEED)
physics = None
augmentation = Augmentation(data_size=2, width_size=8*1, depth=3, context_size=context_size, key=SEED)
augmentation = Augmentation(data_size=2, width_size=8*4, depth=3, context_size=context_size, key=SEED)
contexts = ContextParams(nb_envs, context_size, key=SEED)

# integrator = diffrax.Tsit5()
Expand All @@ -154,15 +159,18 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, alpha, beta, ctx_, key):
trajs_hat, nb_steps = jax.vmap(model, in_axes=(None, None, None, 0))(trajs[:, 0, :], t_eval, ctx, ctx_)
new_trajs = jnp.broadcast_to(trajs, trajs_hat.shape)

# term1 = jnp.mean((new_trajs-trajs_hat)**2)
term1 = jnp.mean((new_trajs-trajs_hat)**2)

weights = jnp.mean((jnp.broadcast_to(ctx, ctx_.shape)-ctx_)**2, axis=-1) + 1e-8
weights = weights / jnp.sum(weights)
term1 = jnp.mean((new_trajs-trajs_hat)**2, axis=(1,2,3)) ## TODO: give more weights to the points for this context itself. Introduce a weighting system
term1 = jnp.sum(term1 * weights)
# weights = jnp.mean((jnp.broadcast_to(ctx, ctx_.shape)-ctx_)**2, axis=-1) + 1e-8
# weights = weights / jnp.sum(weights)
# term1 = jnp.mean((new_trajs-trajs_hat)**2, axis=(1,2,3)) ## TODO: give more weights to the points for this context itself. Introduce a weighting system
# term1 = jnp.sum(term1 * weights)

term2 = 1e-1*jnp.mean((ctx)**2)
loss_val = term1+term2 ### Dangerous, but limit the context TODO
# term3 = 1e-3*spectral_norm_estimation(model.vectorfield.neuralnet, key=key)

# loss_val = term1+term2+term3 ### Dangerous, but limit the context TODO
loss_val = term1+term2

return loss_val, (jnp.sum(nb_steps)/ctx_.shape[0], term1, term2)
#====== New Method ======
Expand All @@ -177,20 +185,19 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, alpha, beta, ctx_, key):

## Define optimiser and train the model

nb_train_steps = nb_epochs * 2
# sched_node = optax.piecewise_constant_schedule(init_value=3e-3,
# boundaries_and_scales={int(nb_train_steps*0.25):0.2,
# int(nb_train_steps*0.5):0.01,
# int(nb_train_steps*0.75):0.05,
# int(nb_train_steps*0.9):0.2})
sched_node = 1e-3
nb_train_steps = nb_epochs * 11
sched_node = optax.piecewise_constant_schedule(init_value=3e-3,
boundaries_and_scales={int(nb_train_steps*0.25):0.2,
int(nb_train_steps*0.5):0.2,
int(nb_train_steps*0.75):0.2})
# sched_node = 1e-3
# sched_node = optax.exponential_decay(3e-3, nb_epochs*2, 0.99)

# sched_ctx = optax.piecewise_constant_schedule(init_value=3e-2,
# boundaries_and_scales={int(nb_epochs*0.25):0.25,
# int(nb_epochs*0.5):0.25,
# int(nb_epochs*0.75):0.25})
sched_ctx = 1e-3
sched_ctx = optax.piecewise_constant_schedule(init_value=3e-2,
boundaries_and_scales={int(nb_epochs*0.25):0.2,
int(nb_epochs*0.5):0.2,
int(nb_epochs*0.75):0.2})
# sched_ctx = 1e-3

opt_node = optax.adabelief(sched_node)
opt_ctx = optax.adabelief(sched_ctx)
Expand Down Expand Up @@ -247,3 +254,4 @@ def loss_fn_ctx(model, trajs, t_eval, ctx, alpha, beta, ctx_, key):
# train_dataloader.dataset[0,0].shape

# trainer.learner.physics.params
# print(augmentation)
24 changes: 22 additions & 2 deletions nodebias/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ def __init__(self, neuralnet, physics=None):
self.physics = physics if physics is not None else ID()

def __call__(self, t, x, ctx, ctx_):
# def __call__(self, t, x, args):
# ctx, ctx_ = args

# print("Shapes of elements:", t.shape, x.shape, ctx.shape, ctx_.shape)

# return self.physics(t, x, ctx) + self.neuralnet(t, x, ctx)

vf = lambda xi_: self.physics(t, x, xi_) + self.neuralnet(t, x, xi_)
gradvf = lambda xi, xi_: eqx.filter_jvp(vf, (xi_,), (xi-xi_,))[1]
return vf(ctx) + gradvf(ctx, ctx_)
gradvf = lambda xi_, xi: eqx.filter_jvp(vf, (xi_,), (xi-xi_,))[1]
return vf(ctx_) + gradvf(ctx_, ctx)
# return vf(ctx)


Expand Down Expand Up @@ -182,6 +184,24 @@ def __call__(self, x0s, t_eval, ctx, ctx_):
batched_ys = jax.vmap(rk4_integrator, in_axes=(None, 0, None))(rhs, x0s, t_eval)
return batched_ys, t_eval.size

# def integrate(x0):
# solution = diffrax.diffeqsolve(
# diffrax.ODETerm(self.vectorfield),
# diffrax.Tsit5(),
# args=(ctx, ctx_),
# t0=t_eval[0],
# t1=t_eval[-1],
# dt0=t_eval[1] - t_eval[0],
# y0=x0,
# stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-4),
# saveat=diffrax.SaveAt(ts=t_eval),
# max_steps=4096*1,
# )
# return solution.ys, solution.stats["num_steps"]

# batched_ys, batched_num_steps = jax.vmap(integrate)(x0s)
# return batched_ys, batched_num_steps




Expand Down
3 changes: 2 additions & 1 deletion nodebias/visualtester.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def test_cf(self, criterion=None, int_cutoff=1.0):

batched_criterion = jax.vmap(jax.vmap(criterion, in_axes=(0, 0)), in_axes=(0, 0))

return batched_criterion(X_hat, X).mean(axis=1).sum(axis=0)
# return batched_criterion(X_hat, X).mean(axis=1).sum(axis=0)
return batched_criterion(X_hat, X).mean(axis=1).mean(axis=0)



Expand Down

0 comments on commit 89e5766

Please sign in to comment.