# LagrangeBench dataset generation
#
# This code was used to generate the 7 datasets from the LagrangeBench paper.
# (Recovered from the first two cells of notebooks/data_gen.ipynb.)

from argparse import Namespace
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import jax.numpy as jnp
import h5py
import numpy as np
from jax import grad, jit, vmap, random, ops
from jax_md import space
from jax_md.partition import Sparse
from functools import partial

from lagrangebench.case_setup import partition


class Kernel:
    """Polymorphic base class for all derived SPH kernels.

    Subclasses must implement ``_w(r)``, the kernel value at radial
    coordinate ``r``; the gradient is obtained by automatic differentiation.
    """

    def __init__(self, h, dim):
        self._dim = dim
        self._h = h
        self._one_over_h = 1.0 / h

    def w(self, r):
        """Evaluate the kernel at the radial coordinate r."""
        return self._w(r)

    def grad_w(self, r):
        """Evaluate the kernel gradient at the radial coordinate r by
        utilization of automatic differentiation."""
        return grad(self.w)(r)


class QuinticKernel(Kernel):
    """The quintic spline kernel of Morris with support radius 3h."""

    def __init__(self, h, dim=3):
        Kernel.__init__(self, h, dim)
        self._normalized_cutoff = 3.0
        self.cutoff = self._normalized_cutoff * h
        # Normalization constants sigma_d so the kernel integrates to one:
        # 1/(120 h), 7/(478 pi h^2), 3/(359 pi h^3) (Morris et al. 1997).
        # BUG FIX: the original used `120 * one_over_h` (= 120/h) in 1D
        # instead of 1/(120 h); only dim == 1 was affected.
        self._sigma_1d = 1.0 / 120.0 * self._one_over_h
        self._sigma_2d = 7.0 / 478.0 / jnp.pi * self._one_over_h * self._one_over_h
        self._sigma_3d = (
            3.0
            / 359.0
            / jnp.pi
            * self._one_over_h
            * self._one_over_h
            * self._one_over_h
        )
        self._sigma = jnp.where(
            dim == 3,
            self._sigma_3d,
            jnp.where(dim == 2, self._sigma_2d, self._sigma_1d),
        )

    def _w(self, r):
        # q = r/h; the three clamped branches implement the piecewise
        # quintic spline with compact support q < 3.
        q = r * self._one_over_h

        q1 = jnp.maximum(0.0, 1.0 - q)
        q2 = jnp.maximum(0.0, 2.0 - q)
        q3 = jnp.maximum(0.0, 3.0 - q)

        return self._sigma * (
            q3 * q3 * q3 * q3 * q3
            - 6.0 * q2 * q2 * q2 * q2 * q2
            + 15.0 * q1 * q1 * q1 * q1 * q1
        )


def write_h5(data_dict, path):
    """Write a dict of numpy or jax arrays to a .h5 file.

    Uses a context manager so the file handle is closed even if a
    dataset write raises.
    """
    with h5py.File(path, "w") as hf:
        for k, v in data_dict.items():
            hf.create_dataset(k, data=np.array(v))


def pos_init_cartesian_2d(box_size, dx):
    """Cell-centered Cartesian particle positions filling a 2D box."""
    n = np.array((box_size / dx).round(), dtype=int)
    grid = np.meshgrid(range(n[0]), range(n[1]), indexing="xy")
    r = (jnp.vstack(list(map(jnp.ravel, grid))).T + 0.5) * dx
    return r


def pos_init_cartesian_3d(box_size, dx):
    """Cell-centered Cartesian particle positions filling a 3D box."""
    n = np.array((box_size / dx).round(), dtype=int)
    grid = np.meshgrid(range(n[0]), range(n[1]), range(n[2]), indexing="xy")
    r = (jnp.vstack(list(map(jnp.ravel, grid))).T + 0.5) * dx
    return r
def SPH(displacement_fn, g_ext_fn, args):
    """Weakly compressible Smoothed Particle Hydrodynamics solver.

    Args:
        displacement_fn: pairwise displacement function (from ``space.periodic``).
        g_ext_fn: maps positions (N, dim) to external accelerations (N, dim).
        args: namespace providing dx, dim, dt, is_bc_trick,
            density_evolution, artificial_alpha, p_bg_factor, u_ref,
            rho_ref, mass, viscosity.

    Returns:
        ``forward(state, neighbors)`` which recomputes "rho", "p" and "dudt"
        ("r" and "tag" are passed through unchanged).
    """
    dx = args.dx
    dim = args.dim
    dt = args.dt
    is_bc_trick = args.is_bc_trick
    is_rho_evol = args.density_evolution
    artificial_alpha = args.artificial_alpha
    p_bg_factor = args.p_bg_factor
    u_ref = args.u_ref
    rho_ref = args.rho_ref
    mass = args.mass
    viscosity = args.viscosity

    kernel_fn = QuinticKernel(h=dx, dim=dim)

    # Weakly-compressible equation of state: artificial speed of sound
    # c_ref = 10 u_ref keeps relative density fluctuations around 1%.
    c_ref = 10.0 * u_ref
    # NOTE: this assignment was duplicated in the original; deduplicated.
    p_ref = rho_ref * c_ref * c_ref
    p_background = p_bg_factor * p_ref

    @jit
    def _pressure_fn(density, density_reference, speed_of_sound, background_pressure):
        """Linear equation of state: density -> pressure."""
        return (
            speed_of_sound * speed_of_sound * (density - density_reference)
            + background_pressure
        )

    pressure_fn = partial(
        _pressure_fn,
        density_reference=rho_ref,
        speed_of_sound=c_ref,
        background_pressure=p_background,
    )

    @jit
    def _density_fn(pressure, density_reference, speed_of_sound, background_pressure):
        """Inverse of `_pressure_fn`: pressure -> density."""
        return (
            pressure - background_pressure
        ) / speed_of_sound / speed_of_sound + density_reference

    density_fn = partial(
        _density_fn,
        density_reference=rho_ref,
        speed_of_sound=c_ref,
        background_pressure=p_background,
    )

    def forward(state, neighbors):
        """One SPH force evaluation over the given sparse neighbor list."""
        r, tag, u, dudt, rho, p = (
            state["r"],
            state["tag"],
            state["u"],
            state["dudt"],
            state["rho"],
            state["p"],
        )
        N = len(r)

        # Per-edge quantities; (i_s, j_s) are the sparse edge endpoints.
        i_s, j_s = neighbors.idx
        r_i_s, r_j_s = r[i_s], r[j_s]
        dr_i_j = vmap(displacement_fn)(r_i_s, r_j_s)
        dist = space.distance(dr_i_j)
        w_dist = vmap(kernel_fn.w)(dist)

        # Kernel gradient; eps guards against division by zero on the
        # self-edge where dist == 0.
        grad_w_dist = (
            vmap(kernel_fn.grad_w)(dist)[:, None]
            * dr_i_j
            / (dist[:, None] + jnp.finfo(float).eps)
        )
        g_ext = g_ext_fn(r)

        # Density: either integrate the continuity equation (density
        # evolution) or use direct kernel summation; pressure via the EoS.
        if is_rho_evol:
            drhodt = rho * ops.segment_sum(
                (mass / rho)[j_s] * ((u[i_s] - u[j_s]) * grad_w_dist).sum(axis=1),
                i_s,
                N,
            )
            rho = rho + dt * drhodt
        else:
            rho = mass * ops.segment_sum(w_dist, i_s, N)
        p = vmap(pressure_fn)(rho)

        if is_bc_trick:
            # Based on: "A generalized wall boundary condition [...]",
            # Adami, Hu, Adams, 2012.
            w_j_s_fluid = w_dist * jnp.where(tag[j_s] == 0, 1.0, 0.0)
            # no-slip boundary condition: extrapolate fluid velocity onto
            # wall particles (tag > 0).
            u_wall = ops.segment_sum(w_j_s_fluid[:, None] * u[j_s], i_s, N) / (
                ops.segment_sum(w_j_s_fluid, i_s, N)[:, None] + jnp.finfo(float).eps
            )
            u = jnp.where(tag[:, None] > 0, 2 * u - u_wall, u)

            # Wall pressure: fluid pressure extrapolation plus a
            # hydrostatic correction from the external acceleration.
            p_wall = (
                ops.segment_sum(w_j_s_fluid * p[j_s], i_s, N)
                + (
                    g_ext
                    * ops.segment_sum(
                        (rho[j_s] * w_j_s_fluid)[:, None] * dr_i_j, i_s, N
                    )
                ).sum(axis=1)
            ) / (ops.segment_sum(w_j_s_fluid, i_s, N) + jnp.finfo(float).eps)
            p = jnp.where(tag > 0, p_wall, p)
            rho = vmap(density_fn)(p)

        def acceleration_fn(r_ij, d_ij, rho_i, rho_j, u_i, u_j, p_i, p_j):
            """Symmetric pressure-gradient + viscous acceleration per edge."""
            p_ij = (rho_j * p_i + rho_i * p_j) / (rho_i + rho_j)
            return (
                ((mass / rho_i) ** 2 + (mass / rho_j) ** 2)
                / mass
                * kernel_fn.grad_w(d_ij)
                / (d_ij + jnp.finfo(float).eps)
                * (-p_ij * r_ij + viscosity * (u_i - u_j))
            )

        res = vmap(acceleration_fn)(
            dr_i_j, dist, rho[i_s], rho[j_s], u[i_s], u[j_s], p[i_s], p[j_s]
        )
        dudt = ops.segment_sum(res, i_s, N)

        # Optional artificial viscosity, applied to fluid-fluid pairs only.
        if artificial_alpha != 0.0:
            numerator = (
                mass
                * artificial_alpha
                * dx
                * c_ref
                * ((u[i_s] - u[j_s]) * dr_i_j).sum(axis=1)
            )[:, None] * grad_w_dist
            denominator = ((rho[i_s] + rho[j_s]) / 2 * (dist**2 + 0.01 * dx**2))[
                :, None
            ]
            res = (
                jnp.where((tag[j_s] == 0) * (tag[i_s] == 0), 1.0, 0.0)[:, None]
                * numerator
                / denominator
            )
            dudt_artif = ops.segment_sum(res, i_s, N)
        else:
            dudt_artif = jnp.zeros_like(dudt)

        return {
            "r": r,
            "tag": tag,
            "u": u,
            "dudt": dudt + g_ext + dudt_artif,
            "rho": rho,
            "p": p,
        }

    return forward


@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    """One semi-implicit Euler step: u += dt*dudt, then r += dt*u (with the
    already-updated u), followed by a neighbor-list update, force
    recomputation, and boundary-condition enforcement."""
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors


# ## Particle Relaxations
#
# The trajectories in, e.g., the 2D Taylor-Green vortex differ in the initial
# particle configuration. To get different physical configurations, for each
# trajectory we run an SPH relaxation of 5000 steps starting from Cartesian
# coordinates plus some Gaussian noise on top. Obtaining a relaxed state for
# a 2D TGV simulation is demonstrated below.
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 / 5000, u_max = 0.0000\n", + "5000 / 5000, u_max = 0.0092\n" + ] + } + ], + "source": [ + "args = Namespace(\n", + " case=\"Rlx\",\n", + " solver=\"SPH\",\n", + " dim=2,\n", + " dx=0.02,\n", + " dt=0.0,\n", + " t_end=0.2,\n", + " box_size=np.array([1.0, 1.0]),\n", + " pbc=np.array([True, True]),\n", + " seed=0,\n", + " write_h5=True,\n", + " write_every=5000,\n", + " r0_noise_factor=0.25,\n", + " viscosity=0.01,\n", + " relax_pbc=True,\n", + " data_path=\"data_relaxed\",\n", + " p_bg_factor=0.0,\n", + " is_bc_trick=False,\n", + " density_evolution=False,\n", + " artificial_alpha=0.0,\n", + " free_slip=False,\n", + " u_ref=1.0,\n", + " rho_ref=1.0,\n", + " sequence_length=5000,\n", + ")\n", + "args.mass = args.dx**args.dim * args.rho_ref\n", + "args.dt = args.dt if args.dt > 0 else 0.25 * args.dx / (11 * args.u_ref)\n", + "\n", + "g_ext_fn = lambda r: jnp.zeros_like(r)\n", + "bc_fn = lambda state: state\n", + "key = random.PRNGKey(args.seed)\n", + "\n", + "displacement_fn, shift_fn = space.periodic(side=args.box_size)\n", + "r_init = pos_init_cartesian_2d(args.box_size, args.dx)\n", + "noise = args.r0_noise_factor * args.dx * random.normal(key, r_init.shape)\n", + "r_init = shift_fn(r_init, noise)\n", + "\n", + "state = {\n", + " \"r\": r_init,\n", + " \"u\": jnp.zeros(r_init.shape),\n", + " \"tag\": jnp.zeros(r_init.shape[0]),\n", + " \"dudt\": jnp.zeros(r_init.shape),\n", + " \"rho\": jnp.ones(r_init.shape[0]),\n", + " \"p\": jnp.zeros(r_init.shape[0]),\n", + "}\n", + "\n", + "# Initialize a neighbor list\n", + "neighbor_fn = partition.neighbor_list(\n", + " displacement_fn,\n", + " args.box_size,\n", + " r_cutoff=3 * args.dx,\n", + " backend=\"jaxmd_vmap\",\n", + " capacity_multiplier=1.25,\n", + " mask_self=False,\n", + " format=Sparse,\n", + " num_particles_max=state[\"r\"].shape[0],\n", + " 
pbc=args.pbc,\n", + ")\n", + "neighbors = neighbor_fn.allocate(state[\"r\"], num_particles=(state[\"tag\"] != -1).sum())\n", + "\n", + "# create data directory\n", + "os.makedirs(args.data_path, exist_ok=True)\n", + "\n", + "solver = SPH(displacement_fn, g_ext_fn, args)\n", + "for step in range(args.sequence_length + 1):\n", + " if step % args.write_every == 0:\n", + " write_h5(state, os.path.join(args.data_path, f\"state_{step}.h5\"))\n", + " state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)\n", + "\n", + " # Check whether the edge list is too small and if so, create longer one\n", + " if neighbors.did_buffer_overflow:\n", + " neighbors = neighbor_fn.allocate(\n", + " state[\"r\"], num_particles=(state[\"tag\"] != -1).sum()\n", + " )\n", + "\n", + " if step % args.write_every == 0:\n", + " u_max = jnp.sqrt(jnp.square(state[\"u\"]).sum(axis=1)).max()\n", + " print(f\"{step} / {args.sequence_length}, u_max = {u_max:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2D TGV" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 / 1000, u_max = 1.0000\n", + "100 / 1000, u_max = 0.9785\n", + "200 / 1000, u_max = 0.9655\n", + "300 / 1000, u_max = 0.9315\n", + "400 / 1000, u_max = 0.9077\n", + "500 / 1000, u_max = 0.8852\n", + "600 / 1000, u_max = 0.8356\n", + "700 / 1000, u_max = 0.8147\n", + "800 / 1000, u_max = 0.7770\n", + "900 / 1000, u_max = 0.7494\n", + "1000 / 1000, u_max = 0.7230\n" + ] + } + ], + "source": [ + "args = Namespace(\n", + " case=\"TGV\",\n", + " solver=\"SPH\",\n", + " dim=2,\n", + " dx=0.02,\n", + " dt=0.0004,\n", + " t_end=5,\n", + " box_size=np.array([1.0, 1.0]),\n", + " pbc=np.array([True, True]),\n", + " seed=0,\n", + " write_h5=True,\n", + " write_every=100,\n", + " r0_noise_factor=0.0,\n", + " viscosity=0.01,\n", + " relax_pbc=True,\n", + " 
data_path=\"datasets/2D_TGV_2500_10kevery100\",\n", + " p_bg_factor=0.0,\n", + " is_bc_trick=False,\n", + " density_evolution=False,\n", + " artificial_alpha=0.0,\n", + " u_ref=1.0,\n", + " rho_ref=1.0,\n", + " sequence_length=1000, # this one should go to 12500\n", + ")\n", + "args.mass = args.dx**args.dim * args.rho_ref\n", + "args.dt = args.dt if args.dt > 0 else 0.25 * args.dx / (11 * args.u_ref)\n", + "\n", + "g_ext_fn = lambda r: jnp.zeros_like(r)\n", + "bc_fn = lambda state: state\n", + "key = random.PRNGKey(args.seed)\n", + "\n", + "# load relaxed state\n", + "r_init = np.array(h5py.File(\"data_relaxed/state_5000.h5\", \"r\")[\"r\"])\n", + "state = {\n", + " \"r\": r_init,\n", + " \"u\": jnp.array(\n", + " [\n", + " -1.0\n", + " * jnp.cos(2.0 * jnp.pi * r_init[:, 0])\n", + " * jnp.sin(2.0 * jnp.pi * r_init[:, 1]),\n", + " +1.0\n", + " * jnp.sin(2.0 * jnp.pi * r_init[:, 0])\n", + " * jnp.cos(2.0 * jnp.pi * r_init[:, 1]),\n", + " ]\n", + " ).T,\n", + " \"tag\": jnp.zeros(r_init.shape[0]),\n", + " \"dudt\": jnp.zeros(r_init.shape),\n", + " \"rho\": jnp.ones(r_init.shape[0]),\n", + " \"p\": jnp.zeros(r_init.shape[0]),\n", + "}\n", + "\n", + "displacement_fn, shift_fn = space.periodic(side=args.box_size)\n", + "noise = args.r0_noise_factor * args.dx * random.normal(key, state[\"r\"].shape)\n", + "state[\"r\"] = shift_fn(state[\"r\"], noise)\n", + "\n", + "# Initialize a neighbor list\n", + "neighbor_fn = partition.neighbor_list(\n", + " displacement_fn,\n", + " args.box_size,\n", + " r_cutoff=3 * args.dx,\n", + " backend=\"jaxmd_vmap\",\n", + " capacity_multiplier=1.25,\n", + " mask_self=False,\n", + " format=Sparse,\n", + " num_particles_max=state[\"r\"].shape[0],\n", + " pbc=args.pbc,\n", + ")\n", + "neighbors = neighbor_fn.allocate(state[\"r\"], num_particles=(state[\"tag\"] != -1).sum())\n", + "\n", + "# create data directory\n", + "os.makedirs(args.data_path, exist_ok=True)\n", + "\n", + "solver = SPH(displacement_fn, g_ext_fn, args)\n", + "for step in 
range(args.sequence_length + 1):\n", + " if step % args.write_every == 0:\n", + " write_h5(state, os.path.join(args.data_path, f\"state_{step}.h5\"))\n", + " state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)\n", + "\n", + " # Check whether the edge list is too small and if so, create longer one\n", + " if neighbors.did_buffer_overflow:\n", + " neighbors = neighbor_fn.allocate(\n", + " state[\"r\"], num_particles=(state[\"tag\"] != -1).sum()\n", + " )\n", + "\n", + " if step % args.write_every == 0:\n", + " u_max = jnp.sqrt(jnp.square(state[\"u\"]).sum(axis=1)).max()\n", + " print(f\"{step} / {args.sequence_length}, u_max = {u_max:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3D TGV" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 / 1000, u_max = 0.9638\n", + "100 / 1000, u_max = 0.9366\n", + "200 / 1000, u_max = 0.7744\n", + "300 / 1000, u_max = 0.9897\n", + "400 / 1000, u_max = 0.8363\n", + "500 / 1000, u_max = 0.6221\n", + "600 / 1000, u_max = 0.6590\n", + "700 / 1000, u_max = 0.6121\n", + "800 / 1000, u_max = 0.4958\n", + "900 / 1000, u_max = 0.4383\n", + "1000 / 1000, u_max = 0.3921\n" + ] + } + ], + "source": [ + "args = Namespace(\n", + " case=\"TGV\",\n", + " solver=\"SPH\",\n", + " dim=3,\n", + " dx=0.314159265,\n", + " dt=0.005,\n", + " t_end=30,\n", + " box_size=np.array([2 * np.pi, 2 * np.pi, 2 * np.pi]),\n", + " pbc=np.array([True, True, True]),\n", + " seed=0,\n", + " write_h5=True,\n", + " write_every=100,\n", + " r0_noise_factor=0.0,\n", + " viscosity=0.02,\n", + " relax_pbc=True,\n", + " data_path=\"datasets/3D_TGV_8000_10kevery100\",\n", + " p_bg_factor=0.0,\n", + " is_bc_trick=False,\n", + " density_evolution=False,\n", + " artificial_alpha=0.0,\n", + " u_ref=1.0,\n", + " rho_ref=1.0,\n", + " sequence_length=1000, # this one should go to 6000\n", + ")\n", + "args.mass 
= args.dx**args.dim * args.rho_ref\n", + "args.dt = args.dt if args.dt > 0 else 0.25 * args.dx / (11 * args.u_ref)\n", + "\n", + "g_ext_fn = lambda r: jnp.zeros_like(r)\n", + "bc_fn = lambda state: state\n", + "key = random.PRNGKey(args.seed)\n", + "\n", + "r_init = pos_init_cartesian_3d(args.box_size, args.dx) # replace with relaxed state\n", + "state = {\n", + " \"r\": r_init,\n", + " \"u\": jnp.array(\n", + " [\n", + " +jnp.sin(r_init[:, 0]) * jnp.cos(r_init[:, 1]) * jnp.cos(r_init[:, 2]),\n", + " -jnp.cos(r_init[:, 0]) * jnp.sin(r_init[:, 1]) * jnp.cos(r_init[:, 2]),\n", + " 0 * r_init[:, 1],\n", + " ]\n", + " ).T,\n", + " \"tag\": jnp.zeros(r_init.shape[0]),\n", + " \"dudt\": jnp.zeros(r_init.shape),\n", + " \"rho\": jnp.ones(r_init.shape[0]),\n", + " \"p\": jnp.zeros(r_init.shape[0]),\n", + "}\n", + "\n", + "displacement_fn, shift_fn = space.periodic(side=args.box_size)\n", + "noise = args.r0_noise_factor * args.dx * random.normal(key, state[\"r\"].shape)\n", + "state[\"r\"] = shift_fn(state[\"r\"], noise)\n", + "\n", + "# Initialize a neighbor list\n", + "neighbor_fn = partition.neighbor_list(\n", + " displacement_fn,\n", + " args.box_size,\n", + " r_cutoff=3 * args.dx,\n", + " backend=\"jaxmd_vmap\",\n", + " capacity_multiplier=1.25,\n", + " mask_self=False,\n", + " format=Sparse,\n", + " num_particles_max=state[\"r\"].shape[0],\n", + " pbc=args.pbc,\n", + ")\n", + "neighbors = neighbor_fn.allocate(state[\"r\"], num_particles=(state[\"tag\"] != -1).sum())\n", + "\n", + "# create data directory\n", + "os.makedirs(args.data_path, exist_ok=True)\n", + "\n", + "solver = SPH(displacement_fn, g_ext_fn, args)\n", + "for step in range(args.sequence_length + 1):\n", + " if step % args.write_every == 0:\n", + " write_h5(state, os.path.join(args.data_path, f\"state_{step}.h5\"))\n", + " state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)\n", + "\n", + " # Check whether the edge list is too small and if so, create longer one\n", + " if 
neighbors.did_buffer_overflow:\n", + " neighbors = neighbor_fn.allocate(\n", + " state[\"r\"], num_particles=(state[\"tag\"] != -1).sum()\n", + " )\n", + "\n", + " if step % args.write_every == 0:\n", + " u_max = jnp.sqrt(jnp.square(state[\"u\"]).sum(axis=1)).max()\n", + " print(f\"{step} / {args.sequence_length}, u_max = {u_max:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2D RPF" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 / 1000, u_max = 0.0000\n", + "100 / 1000, u_max = 0.0500\n", + "200 / 1000, u_max = 0.1000\n", + "300 / 1000, u_max = 0.1500\n", + "400 / 1000, u_max = 0.2000\n", + "500 / 1000, u_max = 0.2500\n", + "600 / 1000, u_max = 0.3000\n", + "700 / 1000, u_max = 0.3500\n", + "800 / 1000, u_max = 0.4000\n", + "900 / 1000, u_max = 0.4500\n", + "1000 / 1000, u_max = 0.5001\n" + ] + } + ], + "source": [ + "args = Namespace(\n", + " case=\"RPF\",\n", + " solver=\"SPH\",\n", + " dim=2,\n", + " dx=0.025,\n", + " dt=0.0005,\n", + " t_end=2050,\n", + " box_size=np.array([1, 2.0]),\n", + " pbc=np.array([True, True]),\n", + " seed=0,\n", + " write_h5=True,\n", + " write_every=100,\n", + " r0_noise_factor=0.0,\n", + " viscosity=0.01,\n", + " relax_pbc=True,\n", + " data_path=\"datasets/2D_RPF_3200_20kevery100\",\n", + " p_bg_factor=0.05,\n", + " is_bc_trick=False,\n", + " density_evolution=False,\n", + " artificial_alpha=0.0,\n", + " u_ref=1.0,\n", + " rho_ref=1.0,\n", + " sequence_length=1000, # this one should go to 4.100.000\n", + ")\n", + "args.mass = args.dx**args.dim * args.rho_ref\n", + "args.dt = args.dt if args.dt > 0 else 0.25 * args.dx / (11 * args.u_ref)\n", + "\n", + "g_ext_fn = lambda r: jnp.where(r[:, 1] > 1.0, -1.0, 1.0)[:, None] * jnp.array(\n", + " [1.0, 0.0]\n", + ")\n", + "bc_fn = lambda state: state\n", + "key = random.PRNGKey(args.seed)\n", + "\n", + "r_init = 
pos_init_cartesian_2d(args.box_size, args.dx) # replace with relaxed state\n", + "state = {\n", + " \"r\": r_init,\n", + " \"u\": jnp.zeros_like(r_init),\n", + " \"tag\": jnp.zeros(r_init.shape[0]),\n", + " \"dudt\": jnp.zeros_like(r_init),\n", + " \"rho\": jnp.ones(r_init.shape[0]),\n", + " \"p\": jnp.zeros(r_init.shape[0]),\n", + "}\n", + "\n", + "displacement_fn, shift_fn = space.periodic(side=args.box_size)\n", + "\n", + "# Initialize a neighbor list\n", + "neighbor_fn = partition.neighbor_list(\n", + " displacement_fn,\n", + " args.box_size,\n", + " r_cutoff=3 * args.dx,\n", + " backend=\"jaxmd_vmap\",\n", + " capacity_multiplier=1.25,\n", + " mask_self=False,\n", + " format=Sparse,\n", + " num_particles_max=state[\"r\"].shape[0],\n", + " pbc=args.pbc,\n", + ")\n", + "neighbors = neighbor_fn.allocate(state[\"r\"], num_particles=(state[\"tag\"] != -1).sum())\n", + "\n", + "# create data directory\n", + "os.makedirs(args.data_path, exist_ok=True)\n", + "\n", + "solver = SPH(displacement_fn, g_ext_fn, args)\n", + "for step in range(args.sequence_length + 1):\n", + " if step % args.write_every == 0:\n", + " write_h5(state, os.path.join(args.data_path, f\"state_{step}.h5\"))\n", + " state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)\n", + "\n", + " # Check whether the edge list is too small and if so, create longer one\n", + " if neighbors.did_buffer_overflow:\n", + " neighbors = neighbor_fn.allocate(\n", + " state[\"r\"], num_particles=(state[\"tag\"] != -1).sum()\n", + " )\n", + "\n", + " if step % args.write_every == 0:\n", + " u_max = jnp.sqrt(jnp.square(state[\"u\"]).sum(axis=1)).max()\n", + " print(f\"{step} / {args.sequence_length}, u_max = {u_max:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3D RPF" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 / 1000, u_max = 0.0000\n", + 
# Shared driver utilities (redefined here so this section is self-contained;
# identical to the helpers used by the earlier cases).


def _allocate_neighbors(args, state, displacement_fn):
    """Build the sparse neighbor-list function and an initial allocation."""
    neighbor_fn = partition.neighbor_list(
        displacement_fn,
        args.box_size,
        r_cutoff=3 * args.dx,
        backend="jaxmd_vmap",
        capacity_multiplier=1.25,
        mask_self=False,
        format=Sparse,
        num_particles_max=state["r"].shape[0],
        pbc=args.pbc,
    )
    neighbors = neighbor_fn.allocate(
        state["r"], num_particles=(state["tag"] != -1).sum()
    )
    return neighbor_fn, neighbors


def _run_case(args, state, g_ext_fn, bc_fn, displacement_fn, shift_fn):
    """Run one case for args.sequence_length steps.

    Writes `state_{step}.h5` snapshots into `args.data_path` every
    `args.write_every` steps and prints the max velocity magnitude of the
    state advanced at that step. Returns the final state.
    """
    neighbor_fn, neighbors = _allocate_neighbors(args, state, displacement_fn)

    # create data directory
    os.makedirs(args.data_path, exist_ok=True)

    solver = SPH(displacement_fn, g_ext_fn, args)
    for step in range(args.sequence_length + 1):
        if step % args.write_every == 0:
            write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
        state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

        # Check whether the edge list is too small and if so, create longer one
        if neighbors.did_buffer_overflow:
            neighbors = neighbor_fn.allocate(
                state["r"], num_particles=(state["tag"] != -1).sum()
            )

        if step % args.write_every == 0:
            u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
            print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")
    return state


# ## 3D RPF
args = Namespace(
    case="RPF",
    solver="SPH",
    dim=3,
    dx=0.05,
    dt=0.001,
    t_end=2050,
    box_size=np.array([1, 2, 0.5]),
    pbc=np.array([True, True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.1,
    relax_pbc=True,
    data_path="datasets/3D_RPF_8000_10kevery100",
    p_bg_factor=0.02,
    is_bc_trick=False,
    density_evolution=False,
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 2.050.000
)
args.mass = args.dx**args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25 * args.dx / (11 * args.u_ref)

# Reverse Poiseuille flow: body force in +x below y=1 and -x above.
g_ext_fn = lambda r: jnp.where(r[:, 1] > 1.0, -1.0, 1.0)[:, None] * jnp.array(
    [1.0, 0.0, 0.0]
)
bc_fn = lambda state: state
key = random.PRNGKey(args.seed)

r_init = pos_init_cartesian_3d(args.box_size, args.dx)  # replace with relaxed state
state = {
    "r": r_init,
    "u": jnp.zeros_like(r_init),
    "tag": jnp.zeros(r_init.shape[0]),
    "dudt": jnp.zeros_like(r_init),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)
noise = args.r0_noise_factor * args.dx * random.normal(key, state["r"].shape)
state["r"] = shift_fn(state["r"], noise)

state = _run_case(args, state, g_ext_fn, bc_fn, displacement_fn, shift_fn)


# ## 2D LDC
args = Namespace(
    case="LDC",
    solver="SPH",
    dim=2,
    dx=0.02,
    dt=0.0004,
    t_end=850,
    box_size=np.array([1.12, 1.12]),
    pbc=np.array([True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.01,
    relax_pbc=True,
    data_path="datasets/2D_LDC_2708_10kevery100",
    p_bg_factor=0.01,
    is_bc_trick=True,
    density_evolution=False,
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 2.125.000
)
args.mass = args.dx**args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25 * args.dx / (11 * args.u_ref)

g_ext_fn = lambda r: jnp.zeros_like(r)


def bc_fn(state):
    """Enforce wall velocities: tag 1 = static wall, tag 2 = moving lid."""
    u_lid = jnp.array([1.0, 0.0])
    state["u"] = jnp.where(state["tag"][:, None] == 1, 0, state["u"])
    state["u"] = jnp.where(state["tag"][:, None] == 2, u_lid, state["u"])
    return state


r_init = pos_init_cartesian_2d(args.box_size, args.dx)  # replace with relaxed state
# tags: {'0': water, '1': solid wall, '2': moving wall}
# (boolean conditions used directly; the original wrapped them in a
# redundant jnp.where(cond, True, False))
tag = jnp.ones(len(r_init), dtype=int)
tag = jnp.where(
    jnp.abs(r_init - r_init.mean(axis=0)).max(axis=1) < 0.5,
    0,
    tag,
)
tag = jnp.where(r_init[:, 1] > 1 + 3 * args.dx, 2, tag)
state = {
    "r": r_init,
    "u": jnp.zeros_like(r_init),
    "tag": tag,
    "dudt": jnp.zeros_like(r_init),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)

state = _run_case(args, state, g_ext_fn, bc_fn, displacement_fn, shift_fn)


# ## 3D LDC
args = Namespace(
    case="LDC",
    solver="SPH",
    dim=3,
    dx=0.041666667,
    dt=0.0009,
    t_end=1850,
    box_size=np.array([1.25, 1.25, 0.5]),
    pbc=np.array([True, True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.01,
    relax_pbc=True,
    data_path="datasets/3D_LDC_8160_10kevery100",
    p_bg_factor=0.01,
    is_bc_trick=True,
    density_evolution=False,
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 2.125.000
)
args.mass = args.dx**args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25 * args.dx / (11 * args.u_ref)

g_ext_fn = lambda r: jnp.zeros_like(r)


def bc_fn(state):
    """Enforce wall velocities: tag 1 = static wall, tag 2 = moving lid."""
    u_lid = jnp.array([1.0, 0.0, 0.0])
    state["u"] = jnp.where(state["tag"][:, None] == 1, 0, state["u"])
    state["u"] = jnp.where(state["tag"][:, None] == 2, u_lid, state["u"])
    return state


r_init = pos_init_cartesian_3d(args.box_size, args.dx)  # replace with relaxed state
# tags: {'0': water, '1': solid wall, '2': moving wall}
tag = jnp.ones(len(r_init), dtype=int)
tag = jnp.where(
    jnp.abs(r_init - r_init.mean(axis=0)).max(axis=1) < 0.5,
    0,
    tag,
)
tag = jnp.where(r_init[:, 1] > 1 + 3 * args.dx, 2, tag)
state = {
    "r": r_init,
    "u": jnp.zeros_like(r_init),
    "tag": tag,
    "dudt": jnp.zeros_like(r_init),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)

state = _run_case(args, state, g_ext_fn, bc_fn, displacement_fn, shift_fn)


# ## 2D DAM
# NOTE(review): the 2D dam-break cell is truncated in this chunk — its
# Namespace breaks off after `viscosity=0.00005,` (known prefix: case="DAM",
# dim=2, dx=0.02, dt=0.0003, t_end=12, box_size=[5.486, 2.12], pbc=[True, True],
# seed=0, write_h5=True, write_every=100, r0_noise_factor=0.0). The remainder
# must be restored from the complete notebook; it is not reconstructed here.
relax_pbc=True,\n", + "    data_path=\"datasets/2D_DAM_5740_20kevery100\",\n", + "    p_bg_factor=0.01,\n", + "    is_bc_trick=True,\n", + "    density_evolution=False,\n", + "    artificial_alpha=0.1,\n", + "    u_ref=2**0.5,\n", + "    rho_ref=1.0,\n", + "    sequence_length=1000,  # this one should go to 40.000\n", + ")\n", + "args.mass = args.dx**args.dim * args.rho_ref\n", + "args.dt = args.dt if args.dt > 0 else 0.25 * args.dx / (11 * args.u_ref)\n", + "\n", + "g_ext_fn = lambda r: jnp.ones_like(r) * jnp.array([0.0, 1.0])\n", + "\n", + "\n", + "def bc_fn(state):\n", + "    state[\"u\"] = jnp.where(state[\"tag\"][:, None] == 1, 0, state[\"u\"])\n", + "    return state\n", + "\n", + "\n", + "key = random.PRNGKey(args.seed)\n", + "\n", + "\n", + "def init_dam():\n", + "    L_wall = 5.366\n", + "    H_wall = 2.0\n", + "    L = 2.0\n", + "    H = 1.0\n", + "    dx = args.dx\n", + "    dx3 = 3 * args.dx\n", + "    dx6 = 6 * args.dx\n", + "\n", + "    r_fluid = dx3 + pos_init_cartesian_2d(np.array([L, H]), dx)\n", + "    # horizontal and vertical blocks\n", + "    vertical = pos_init_cartesian_2d(np.array([dx3, H_wall + dx6]), dx)\n", + "    horiz = pos_init_cartesian_2d(np.array([L_wall, dx3]), dx)\n", + "    # wall: left, bottom, right, top\n", + "    wall_l = vertical.copy()\n", + "    wall_b = horiz.copy() + np.array([dx3, 0.0])\n", + "    wall_r = vertical.copy() + np.array([L_wall + dx3, 0.0])\n", + "    wall_t = horiz.copy() + np.array([dx3, H_wall + dx3])\n", + "\n", + "    res = np.concatenate([wall_l, wall_b, wall_r, wall_t, r_fluid])\n", + "    # tag the walls as \"1\" and the fluid as \"0\"\n", + "    tag = np.ones(len(wall_l) + len(wall_b) + len(wall_r) + len(wall_t))\n", + "    tag = np.concatenate([tag, np.zeros(len(r_fluid))])\n", + "    return res, tag\n", + "\n", + "\n", + "r_init, tag = init_dam()  # replace with relaxed fluid state\n", + "state = {\n", + "    \"r\": r_init,\n", + "    \"u\": jnp.zeros_like(r_init),\n", + "    \"tag\": tag,\n", + "    \"dudt\": jnp.zeros_like(r_init),\n", + "    \"rho\": jnp.ones(r_init.shape[0]),\n", + "    \"p\": 
jnp.zeros(r_init.shape[0]),\n", + "}\n", + "\n", + "displacement_fn, shift_fn = space.periodic(side=args.box_size)\n", + "\n", + "# Initialize a neighbor list\n", + "neighbor_fn = partition.neighbor_list(\n", + " displacement_fn,\n", + " args.box_size,\n", + " r_cutoff=3 * args.dx,\n", + " backend=\"jaxmd_vmap\",\n", + " capacity_multiplier=1.25,\n", + " mask_self=False,\n", + " format=Sparse,\n", + " num_particles_max=state[\"r\"].shape[0],\n", + " pbc=args.pbc,\n", + ")\n", + "neighbors = neighbor_fn.allocate(state[\"r\"], num_particles=(state[\"tag\"] != -1).sum())\n", + "\n", + "# create data directory\n", + "os.makedirs(args.data_path, exist_ok=True)\n", + "\n", + "solver = SPH(displacement_fn, g_ext_fn, args)\n", + "for step in range(args.sequence_length + 1):\n", + " if step % args.write_every == 0:\n", + " write_h5(state, os.path.join(args.data_path, f\"state_{step}.h5\"))\n", + " state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)\n", + "\n", + " # Check whether the edge list is too small and if so, create longer one\n", + " if neighbors.did_buffer_overflow:\n", + " neighbors = neighbor_fn.allocate(\n", + " state[\"r\"], num_particles=(state[\"tag\"] != -1).sum()\n", + " )\n", + "\n", + " if step % args.write_every == 0:\n", + " u_max = jnp.sqrt(jnp.square(state[\"u\"]).sum(axis=1)).max()\n", + " print(f\"{step} / {args.sequence_length}, u_max = {u_max:.4f}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}