Skip to content

Commit

Permalink
Merge pull request #20 from tumaer/batched_rollout
Browse files Browse the repository at this point in the history
Batched rollout
  • Loading branch information
arturtoshev authored Jan 9, 2024
2 parents af8d7cd + 5a7a0fc commit 45f5900
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 64 deletions.
3 changes: 0 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ repos:
- id: check-docstring-first
- id: check-json
- id: check-toml
- id: check-xml
- id: check-yaml
- id: trailing-whitespace
- id: end-of-file-fixer
- id: requirements-txt-fixer
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.1.8'
Expand Down
64 changes: 37 additions & 27 deletions lagrangebench/evaluate/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from lagrangebench.data.utils import numpy_collate
from lagrangebench.defaults import defaults
from lagrangebench.evaluate.metrics import MetricsComputer, MetricsDict
from lagrangebench.evaluate.utils import write_vtk
from lagrangebench.utils import (
broadcast_from_batch,
broadcast_to_batch,
get_kinematic_mask,
load_haiku,
set_seed,
write_vtk,
)


Expand Down Expand Up @@ -75,27 +75,28 @@ def _forward_eval(


def eval_batched_rollout(
model_apply: Callable,
forward_eval_vmap: Callable,
preprocess_eval_vmap: Callable,
case,
params: hk.Params,
state: hk.State,
traj_batch_i: Tuple[jnp.ndarray, jnp.ndarray],
neighbors: partition.NeighborList,
metrics_computer: MetricsComputer,
metrics_computer_vmap: Callable,
n_rollout_steps: int,
t_window: int,
n_extrap_steps: int = 0,
) -> Tuple[jnp.ndarray, MetricsDict, jnp.ndarray]:
"""Compute the rollout on a single trajectory.
Args:
model_apply: Model function.
forward_eval_vmap: Vectorized (vmapped over the batch) forward evaluation function.
case: CaseSetupFn class.
params: Haiku params.
state: Haiku state.
traj_batch_i: Trajectory to evaluate.
neighbors: Neighbor list.
metrics_computer: MetricsComputer with the desired metrics.
metrics_computer_vmap: Vectorized MetricsComputer with the desired metrics.
n_rollout_steps: Number of rollout steps.
t_window: Length of the input sequence.
n_extrap_steps: Number of extrapolation steps (beyond the ground truth rollout).
Expand All @@ -105,7 +106,8 @@ def eval_batched_rollout(
"""
# particle type is treated as a static property defined by state at t=0
pos_input_batch, particle_type_batch = traj_batch_i
batch_size, n_nodes_max, _, dim = pos_input_batch.shape
# current_batch_size might be < eval_batch_size if the last batch is not full
current_batch_size, n_nodes_max, _, dim = pos_input_batch.shape

# if n_rollout_steps set to -1, use the whole trajectory
if n_rollout_steps == -1:
Expand All @@ -116,16 +118,8 @@ def eval_batched_rollout(
traj_len = n_rollout_steps + n_extrap_steps
target_positions_batch = pos_input_batch[:, :, t_window : t_window + traj_len]

predictions_batch = jnp.zeros((batch_size, traj_len, n_nodes_max, dim))
neighbors_batch = broadcast_to_batch(neighbors, batch_size)
preprocess_eval_vmap = vmap(case.preprocess_eval, in_axes=(0, 0))

forward_eval = partial(
_forward_eval,
model_apply=model_apply,
case_integrate=case.integrate,
)
forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0))
predictions_batch = jnp.zeros((current_batch_size, traj_len, n_nodes_max, dim))
neighbors_batch = broadcast_to_batch(neighbors, current_batch_size)

step = 0
while step < n_rollout_steps + n_extrap_steps:
Expand All @@ -149,11 +143,10 @@ def eval_batched_rollout(
print(
f"(eval) From {neighbors_batch.idx[ind].shape} to {nbrs_temp.idx.shape}"
)
neighbors_batch = broadcast_to_batch(nbrs_temp, batch_size)
neighbors_batch = broadcast_to_batch(nbrs_temp, current_batch_size)

# To run the loop N times even if sometimes
# did_buffer_overflow > 0 we directly return to the beginning

continue

# 3. run forward model
Expand All @@ -176,7 +169,7 @@ def eval_batched_rollout(

# (batch, n_nodes, time, dim) -> (batch, time, n_nodes, dim)
target_positions_batch = target_positions_batch.transpose(0, 2, 1, 3)
metrics_batch = vmap(metrics_computer)(predictions_batch, target_positions_batch)
metrics_batch = metrics_computer_vmap(predictions_batch, target_positions_batch)

return (predictions_batch, metrics_batch, broadcast_from_batch(neighbors_batch, 0))

Expand Down Expand Up @@ -221,26 +214,42 @@ def eval_rollout(
if rollout_dir is not None:
os.makedirs(rollout_dir, exist_ok=True)

forward_eval = partial(
_forward_eval,
model_apply=model_apply,
case_integrate=case.integrate,
)
forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0))
preprocess_eval_vmap = vmap(case.preprocess_eval, in_axes=(0, 0))
metrics_computer_vmap = vmap(metrics_computer, in_axes=(0, 0))

for i, traj_batch_i in enumerate(loader_eval):
# if n_trajs is not a multiple of batch_size, we slice from the last batch
n_traj_left = n_trajs - i * batch_size
if n_traj_left < batch_size:
traj_batch_i = jax.tree_map(lambda x: x[:n_traj_left], traj_batch_i)

# numpy to jax
traj_batch_i = jax.tree_map(lambda x: jnp.array(x), traj_batch_i)
# (pos_input_batch, particle_type_batch) = traj_batch_i
# pos_input_batch.shape = (batch, num_particles, seq_length, dim)

example_rollout_batch, metrics_batch, neighbors = eval_batched_rollout(
model_apply=model_apply,
forward_eval_vmap=forward_eval_vmap,
preprocess_eval_vmap=preprocess_eval_vmap,
case=case,
params=params,
state=state,
traj_batch_i=traj_batch_i, # (batch, nodes, t, dim)
neighbors=neighbors,
metrics_computer=metrics_computer,
metrics_computer_vmap=metrics_computer_vmap,
n_rollout_steps=n_rollout_steps,
t_window=t_window,
n_extrap_steps=n_extrap_steps,
)

for j in range(batch_size):
current_batch_size = traj_batch_i[0].shape[0]
for j in range(current_batch_size):
# write metrics to output dictionary
ind = i * batch_size + j
eval_metrics[f"rollout_{ind}"] = broadcast_from_batch(metrics_batch, j)
Expand All @@ -249,30 +258,31 @@ def eval_rollout(
# (batch, nodes, t, dim) -> (batch, t, nodes, dim)
pos_input_batch = traj_batch_i[0].transpose(0, 2, 1, 3)

for j in range(batch_size): # write every trajectory to file
for j in range(current_batch_size): # write every trajectory to file
pos_input = pos_input_batch[j]
example_rollout = example_rollout_batch[j]

initial_positions = pos_input[:t_window]
example_full = jnp.concatenate([initial_positions, example_rollout])
example_rollout = {
"predicted_rollout": example_full, # (t, nodes, dim)
"ground_truth_rollout": pos_input, # (t, nodes, dim)
"ground_truth_rollout": pos_input, # (t, nodes, dim),
"particle_type": traj_batch_i[1][j], # (nodes,)
}

file_prefix = f"{rollout_dir}/rollout_{i*batch_size+j}"
file_prefix = os.path.join(rollout_dir, f"rollout_{i*batch_size+j}")
if out_type == "vtk": # write vtk files for each time step
for k in range(pos_input.shape[0]):
# predictions
state_vtk = {
"r": example_rollout["predicted_rollout"][k],
"tag": traj_batch_i[1][j],
"tag": example_rollout["particle_type"],
}
write_vtk(state_vtk, f"{file_prefix}_{k}.vtk")
# ground truth reference
state_vtk = {
"r": example_rollout["ground_truth_rollout"][k],
"tag": traj_batch_i[1][j],
"tag": example_rollout["particle_type"],
}
write_vtk(state_vtk, f"{file_prefix}_ref_{k}.vtk")
if out_type == "pkl":
Expand Down
76 changes: 76 additions & 0 deletions lagrangebench/evaluate/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Utility functions for evaluation."""

import os
import pickle

import numpy as np


def write_vtk(data_dict, path):
    """Store a .vtk file for ParaView.

    Args:
        data_dict: Mapping from field name to array-like. Must contain the key
            "r" holding the positions with shape (N, dim). Every other entry is
            attached to the point cloud under its own key.
        path: Destination file path handed to pyvista's ``save``.

    Raises:
        ImportError: If pyvista is not installed.
    """
    try:
        import pyvista
    except ImportError as err:
        # chain the original error so the real import failure stays visible
        raise ImportError("Please install pyvista to write VTK files.") from err

    r = np.asarray(data_dict["r"])
    N, dim = r.shape

    # PyVista treats the position information differently than the rest
    if dim == 2:
        r = np.hstack([r, np.zeros((N, 1))])
    data_pv = pyvista.PolyData(r)

    # copy all the other information also to pyvista, using plain numpy arrays
    for k, v in data_dict.items():
        # skip r because we already considered it above
        if k == "r":
            continue

        # convert first, so that lists and scalars also expose `.ndim`
        v = np.asarray(v)

        # working in 3D or scalar features do not require special care;
        # in 2D, two-dimensional features get a zero column appended like `r`
        if dim == 2 and v.ndim == 2:
            v = np.hstack([v, np.zeros((N, 1))])

        data_pv[k] = v

    data_pv.save(path)


def pkl2vtk(src_path, dst_path=None):
    """Convert a rollout pickle file to a set of vtk files.

    Args:
        src_path (str): Source path to .pkl file.
        dst_path (str, optional): Destination directory path. Defaults to None.
            If None, then the vtk files are saved in the same directory as the
            pkl file.

    Example:
        pkl2vtk("rollout/test/rollout_0.pkl", "rollout/test_vtk")
        will create files rollout_0_0.vtk, rollout_0_1.vtk, etc. in the directory
        "rollout/test_vtk"
    """
    # set up destination directory
    if dst_path is None:
        # dirname is "" for a bare file name, which os.makedirs rejects;
        # fall back to the current directory in that case
        dst_path = os.path.dirname(src_path) or os.curdir
    os.makedirs(dst_path, exist_ok=True)

    # load rollout
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted files
    with open(src_path, "rb") as f:
        rollout = pickle.load(f)

    # splitext strips only the final extension, so dotted stems stay distinct
    stem = os.path.splitext(os.path.basename(src_path))[0]
    file_prefix = os.path.join(dst_path, stem)
    particle_type = rollout["particle_type"]

    for k in range(rollout["predicted_rollout"].shape[0]):
        # predictions
        state_vtk = {
            "r": rollout["predicted_rollout"][k],
            "tag": particle_type,
        }
        write_vtk(state_vtk, f"{file_prefix}_{k}.vtk")
        # ground truth reference
        state_vtk = {
            "r": rollout["ground_truth_rollout"][k],
            "tag": particle_type,
        }
        write_vtk(state_vtk, f"{file_prefix}_ref_{k}.vtk")
31 changes: 0 additions & 31 deletions lagrangebench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,37 +143,6 @@ def print_params_shapes(params, prefix=""):
print_params_shapes(v, prefix=prefix + k)


def write_vtk(data_dict, path):
    """Store a .vtk file for ParaView.

    Args:
        data_dict: Mapping from field name to array-like. Must contain the key
            "r" holding the positions with shape (N, dim). Every other entry is
            attached to the point cloud under its own key.
        path: Destination file path handed to pyvista's ``save``.

    Raises:
        ImportError: If pyvista is not installed.
    """
    try:
        import pyvista
    except ImportError as err:
        # chain the original error so the real import failure stays visible
        raise ImportError("Please install pyvista to write VTK files.") from err

    r = np.asarray(data_dict["r"])
    N, dim = r.shape

    # PyVista treats the position information differently than the rest
    if dim == 2:
        r = np.hstack([r, np.zeros((N, 1))])
    data_pv = pyvista.PolyData(r)

    # copy all the other information also to pyvista, using plain numpy arrays
    for k, v in data_dict.items():
        # skip r because we already considered it above
        if k == "r":
            continue

        # convert first, so that lists and scalars also expose `.ndim`
        v = np.asarray(v)

        # working in 3D or scalar features do not require special care;
        # in 2D, two-dimensional features get a zero column appended like `r`
        if dim == 2 and v.ndim == 2:
            v = np.hstack([v, np.zeros((N, 1))])

        data_pv[k] = v

    data_pv.save(path)


def set_seed(seed: int) -> Tuple[jax.Array, Callable, torch.Generator]:
"""Set seeds for jax, random and torch."""
# first PRNG key
Expand Down
17 changes: 14 additions & 3 deletions tests/rollout_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from argparse import Namespace
from functools import partial

import haiku as hk
import jax
Expand All @@ -16,7 +17,7 @@
from lagrangebench.data import H5Dataset
from lagrangebench.data.utils import get_dataset_stats, numpy_collate
from lagrangebench.evaluate import MetricsComputer
from lagrangebench.evaluate.rollout import eval_batched_rollout
from lagrangebench.evaluate.rollout import _forward_eval, eval_batched_rollout
from lagrangebench.utils import broadcast_from_batch


Expand Down Expand Up @@ -127,14 +128,24 @@ def model(x):
isl,
)

example_rollout_batch, metrics_batch, neighbors = eval_batched_rollout(
forward_eval = partial(
_forward_eval,
model_apply=model_apply,
case_integrate=self.case.integrate,
)
forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0))
preprocess_eval_vmap = vmap(self.case.preprocess_eval, in_axes=(0, 0))
metrics_computer_vmap = vmap(metrics_computer, in_axes=(0, 0))

example_rollout_batch, metrics_batch, neighbors = eval_batched_rollout(
forward_eval_vmap=forward_eval_vmap,
preprocess_eval_vmap=preprocess_eval_vmap,
case=self.case,
params=params,
state=state,
traj_batch_i=traj_batch_i,
neighbors=neighbors,
metrics_computer=metrics_computer,
metrics_computer_vmap=metrics_computer_vmap,
n_rollout_steps=self.config.n_rollout_steps,
t_window=isl,
)
Expand Down

0 comments on commit 45f5900

Please sign in to comment.