Skip to content

Commit

Permalink
Update Particle Container to Pure SoA (#348)
Browse files Browse the repository at this point in the history
Transition particle containers to pure SoA layouts.
  • Loading branch information
ax3l authored Feb 9, 2024
1 parent 7259a22 commit 9876a9e
Show file tree
Hide file tree
Showing 51 changed files with 1,015 additions and 795 deletions.
2 changes: 1 addition & 1 deletion cmake/dependencies/ABLASTR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ set(ImpactX_openpmd_src ""
set(ImpactX_ablastr_repo "https://github.com/ECP-WarpX/WarpX.git"
CACHE STRING
"Repository URI to pull and build ABLASTR from if(ImpactX_ablastr_internal)")
set(ImpactX_ablastr_branch "24.02"
set(ImpactX_ablastr_branch "11aabdca56335c5ae1cbb2257b8abd6c8f04a67c"
CACHE STRING
"Repository branch for ImpactX_ablastr_repo if(ImpactX_ablastr_internal)")

Expand Down
2 changes: 1 addition & 1 deletion cmake/dependencies/pyAMReX.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ option(ImpactX_pyamrex_internal "Download & build pyAMReX" ON)
set(ImpactX_pyamrex_repo "https://github.com/AMReX-Codes/pyamrex.git"
CACHE STRING
"Repository URI to pull and build pyamrex from if(ImpactX_pyamrex_internal)")
set(ImpactX_pyamrex_branch "24.02"
set(ImpactX_pyamrex_branch "5aa700de18a61f933cb435adbe2299d74d794d6b"
CACHE STRING
"Repository branch for ImpactX_pyamrex_repo if(ImpactX_pyamrex_internal)")

Expand Down
2 changes: 1 addition & 1 deletion examples/epac2004_benchmarks/input_fodo_rf_SC.in
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ geometry.prob_relative = 4.0
###############################################################################
# Diagnostics
###############################################################################
diag.slice_step_diagnostics = true
diag.slice_step_diagnostics = false
22 changes: 11 additions & 11 deletions examples/fodo/run_fodo_programmable.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ def my_drift(pge, pti, refpart):

else:
array = np.array
# access AoS data such as positions and cpu/id
aos = pti.aos()
aos_arr = array(aos, copy=False)

# access SoA data such as momentum
# access particle attributes
soa = pti.soa()
real_arrays = soa.GetRealData()
px = array(real_arrays[0], copy=False)
py = array(real_arrays[1], copy=False)
pt = array(real_arrays[2], copy=False)
real_arrays = soa.get_real_data()
x = array(real_arrays[0], copy=False)
y = array(real_arrays[1], copy=False)
t = array(real_arrays[2], copy=False)
px = array(real_arrays[3], copy=False)
py = array(real_arrays[4], copy=False)
pt = array(real_arrays[5], copy=False)

# length of the current slice
slice_ds = pge.ds / pge.nslice
Expand All @@ -96,9 +96,9 @@ def my_drift(pge, pti, refpart):
betgam2 = pt_ref**2 - 1.0

# advance position and momentum (drift)
aos_arr[:]["x"] += slice_ds * px[:]
aos_arr[:]["y"] += slice_ds * py[:]
aos_arr[:]["z"] += (slice_ds / betgam2) * pt[:]
x[:] += slice_ds * px[:]
y[:] += slice_ds * py[:]
t[:] += (slice_ds / betgam2) * pt[:]


def my_ref_drift(pge, refpart):
Expand Down
102 changes: 73 additions & 29 deletions examples/pytorch_surrogate_model/run_ml_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
from urllib import request

import numpy as np

try:
import cupy as cp

cupy_available = True
except ImportError:
cupy_available = False

from surrogate_model_definitions import surrogate_model

try:
Expand All @@ -20,14 +28,34 @@
sys.exit(0)

from impactx import (
Config,
CoordSystem,
ImpactX,
ImpactXParIter,
TransformationDirection,
coordinate_transformation,
distribution,
elements,
)

# CPU/GPU logic
# Selects the array constructor (`array`), the stacking function (`stack`)
# and the torch `device` used further down in this script:
# cupy + "cuda" only when Config.have_gpu is true AND the guarded cupy import
# above succeeded; numpy + "cpu" in every other case.
if Config.have_gpu:
if cupy_available:
array = cp.array
stack = cp.stack
device = torch.device("cuda")
else:
# Config.have_gpu is true but cupy could not be imported:
# warn and fall back to numpy arrays with the CPU torch device
print("Warning: GPU found but cupy not available! Try managed...")
array = np.array
stack = np.stack
device = torch.device("cpu")
if Config.gpu_backend == "SYCL":
print("Warning: SYCL GPU backend not yet implemented for Python")

else:
array = np.array
stack = np.stack
device = torch.device("cpu")


def download_and_unzip(url, data_dir):
request.urlretrieve(url, data_dir)
Expand All @@ -50,6 +78,7 @@ def download_and_unzip(url, data_dir):
surrogate_model(
dataset_dir + f"dataset_beam_stage_{i}.pt",
model_dir + f"beam_stage_{i}_model.pt",
device=device,
)
for i in range(N_stage)
]
Expand Down Expand Up @@ -78,47 +107,62 @@ def __init__(self, stage_i, surrogate_model, surrogate_length, stage_start):
self.ds = surrogate_length

def surrogate_push(self, pc, step):
array = np.array

ref_part = pc.ref_particle()
ref_z_i = ref_part.z
ref_z_i_LPA = ref_z_i - self.stage_start
ref_z_f = ref_z_i + self.surrogate_length

ref_part_tensor = torch.tensor(
[ref_part.x, ref_part.y, ref_z_i_LPA, ref_part.px, ref_part.py, ref_part.pz]
[
ref_part.x,
ref_part.y,
ref_z_i_LPA,
ref_part.px,
ref_part.py,
ref_part.pz,
],
dtype=torch.float64,
device=device,
)
ref_beta_gamma = np.sqrt(torch.sum(ref_part_tensor[3:] ** 2))
ref_beta_gamma = torch.sqrt(torch.sum(ref_part_tensor[3:] ** 2))

with torch.no_grad():
ref_part_model_final = self.surrogate_model(ref_part_tensor.float())
ref_part_model_final = self.surrogate_model(ref_part_tensor)
ref_uz_f = ref_part_model_final[5]
ref_beta_gamma_final = (
ref_uz_f # NOT np.sqrt(torch.sum(ref_part_model_final[3:]**2))
)
ref_part_final = torch.tensor([0, 0, ref_z_f, 0, 0, ref_uz_f])
ref_part_final = torch.tensor(
[0, 0, ref_z_f, 0, 0, ref_uz_f], dtype=torch.float64, device=device
)

# transform
coordinate_transformation(pc, TransformationDirection.to_fixed_t)
coordinate_transformation(pc, direction=CoordSystem.t)

for lvl in range(pc.finest_level + 1):
for pti in ImpactXParIter(pc, level=lvl):
aos = pti.aos()
aos_arr = array(aos, copy=False)

soa = pti.soa()
real_arrays = soa.GetRealData()
px = array(real_arrays[0], copy=False)
py = array(real_arrays[1], copy=False)
pt = array(real_arrays[2], copy=False)
data_arr = (
torch.tensor(
np.vstack(
[aos_arr["x"], aos_arr["y"], aos_arr["z"], real_arrays[:3]]
)
)
.float()
.T
real_arrays = soa.get_real_data()
x = array(real_arrays[0], copy=False)
y = array(real_arrays[1], copy=False)
t = array(real_arrays[2], copy=False)
px = array(real_arrays[3], copy=False)
py = array(real_arrays[4], copy=False)
pt = array(real_arrays[5], copy=False)
# Pack the six particle attributes into one tensor; stacking along axis=1
# makes each attribute one column, in the order x, y, t, px, py, pt.
# This is the same column order that is written back into the particle
# arrays after the surrogate model has been applied.
data_arr = torch.tensor(
    stack(
        [
            x,
            y,
            t,
            px,
            py,
            pt,  # column 5 must be pt (not a second py): it is written back to pt
        ],
        axis=1,
    ),
    dtype=torch.float64,
    device=device,
)

data_arr[:, 0] += ref_part.x
Expand All @@ -135,7 +179,7 @@ def surrogate_push(self, pc, step):
# # assume for now it is

with torch.no_grad():
data_arr_post_model = self.surrogate_model(data_arr.float())
data_arr_post_model = self.surrogate_model(data_arr)

# need to add stage start to z
data_arr_post_model[:, 2] += self.stage_start
Expand All @@ -146,9 +190,9 @@ def surrogate_push(self, pc, step):
data_arr_post_model[:, 3 + ii] -= ref_part_final[3 + ii]
data_arr_post_model[:, 3 + ii] /= ref_beta_gamma_final

aos_arr["x"] = data_arr_post_model[:, 0]
aos_arr["y"] = data_arr_post_model[:, 1]
aos_arr["z"] = data_arr_post_model[:, 2]
x[:] = data_arr_post_model[:, 0]
y[:] = data_arr_post_model[:, 1]
t[:] = data_arr_post_model[:, 2]
px[:] = data_arr_post_model[:, 3]
py[:] = data_arr_post_model[:, 4]
pt[:] = data_arr_post_model[:, 5]
Expand All @@ -160,7 +204,7 @@ def surrogate_push(self, pc, step):
ref_part.x = ref_part_final[0]
ref_part.y = ref_part_final[1]
ref_part.z = ref_part_final[2]
ref_gamma = np.sqrt(1 + ref_beta_gamma_final**2)
ref_gamma = torch.sqrt(1 + ref_beta_gamma_final**2)
ref_part.px = ref_part_final[3]
ref_part.py = ref_part_final[4]
ref_part.pz = ref_part_final[5]
Expand All @@ -173,7 +217,7 @@ def surrogate_push(self, pc, step):
# ref_part.s += pge1.ds
# ref_part.t += pge1.ds / ref_beta

coordinate_transformation(pc, TransformationDirection.to_fixed_s)
coordinate_transformation(pc, direction=CoordSystem.s)
## Done!


Expand Down
26 changes: 17 additions & 9 deletions examples/pytorch_surrogate_model/surrogate_model_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ def __init__(self, n_in, n_out, n_hidden_nodes, n_hidden_layers, act):
class surrogate_model:
""" """

def __init__(self, dataset_file, model_file):
def __init__(self, dataset_file, model_file, device):
self.dataset = torch.load(dataset_file)
model_dict = torch.load(model_file, map_location=torch.device("cpu"))
self.device = device
model_dict = torch.load(model_file)
n_in = model_dict["model_state_dict"]["stack.0.weight"].shape[1]
final_layer_key = list(model_dict["model_state_dict"].keys())[-1]
n_out = model_dict["model_state_dict"][final_layer_key].shape[0]
Expand All @@ -112,13 +113,20 @@ def __init__(self, dataset_file, model_file):
self.neural_network.load_state_dict(model_dict["model_state_dict"])
self.neural_network.eval()

def __call__(self, data_arr):
data_arr -= self.dataset["source_means"]
data_arr /= self.dataset["source_stds"]
data_arr = data_arr.float()
def __call__(self, data_arr, device=None):
data_arr -= torch.tensor(
self.dataset["source_means"], dtype=torch.float64, device=device
)
data_arr /= torch.tensor(
self.dataset["source_stds"], dtype=torch.float64, device=device
)
with torch.no_grad():
data_arr_post_model = self.neural_network(data_arr)
data_arr_post_model = self.neural_network(data_arr.float()).double()

data_arr_post_model *= self.dataset["target_stds"]
data_arr_post_model += self.dataset["target_means"]
data_arr_post_model *= torch.tensor(
self.dataset["target_stds"], dtype=torch.float64, device=device
)
data_arr_post_model += torch.tensor(
self.dataset["target_means"], dtype=torch.float64, device=device
)
return data_arr_post_model
18 changes: 10 additions & 8 deletions src/particles/CollectLost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <AMReX_GpuLaunch.H>
#include <AMReX_GpuQualifiers.H>
#include <AMReX_Math.H>
#include <AMReX_Particle.H>
#include <AMReX_ParticleTransformation.H>
#include <AMReX_RandomEngine.H>

Expand All @@ -27,9 +28,9 @@ namespace impactx
using DstData = ImpactXParticleContainer::ParticleTileType::ParticleTileDataType;

AMREX_GPU_HOST_DEVICE
void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept {
dst.m_aos[dst_ip] = src.m_aos[src_ip];

void operator() (DstData const &dst, SrcData const &src, int src_ip, int dst_ip) const noexcept
{
dst.m_idcpu[dst_ip] = src.m_idcpu[src_ip];
for (int j = 0; j < SrcData::NAR; ++j)
dst.m_rdata[j][dst_ip] = src.m_rdata[j][src_ip];
for (int j = 0; j < src.m_num_runtime_real; ++j)
Expand All @@ -42,7 +43,7 @@ namespace impactx
// dst.m_runtime_idata[j][dst_ip] = src.m_runtime_idata[j][src_ip];

// flip id to positive in destination
dst.id(dst_ip) = amrex::Math::abs(dst.id(dst_ip));
amrex::ParticleIDWrapper{dst.m_idcpu[dst_ip]}.make_valid();

// remember the current s of the ref particle when lost
dst.m_runtime_rdata[s_index][dst_ip] = s_lost;
Expand Down Expand Up @@ -85,7 +86,7 @@ namespace impactx
auto const predicate = [] AMREX_GPU_HOST_DEVICE (const SrcData& src, int ip)
/* NVCC 11.3.109 chokes in C++17 on this: noexcept */
{
return src.id(ip) < 0;
return !amrex::ConstParticleIDWrapper{src.m_idcpu[ip]}.is_valid();
};

auto& ptile_dest = dest.DefineAndReturnParticleTile(
Expand Down Expand Up @@ -130,9 +131,11 @@ namespace impactx
{
int n_removed = 0;
auto ptile_src_data = ptile_source.getParticleTileData();
// Bind by reference: a plain `auto const` would copy the tile's whole
// struct-of-arrays (every particle component) just to inspect the ids.
// Reading the live array is equivalent here, because the compaction loop
// below only ever writes to indices <= the one currently being read.
auto const& ptile_soa = ptile_source.GetStructOfArrays();
// read-only pointer to the packed id/cpu words, used below for the validity test
auto const ptile_idcpu = ptile_soa.GetIdCPUData().dataPtr();
for (int ip = 0; ip < np; ++ip)
{
if (ptile_source.id(ip) < 0)
if (!amrex::ConstParticleIDWrapper{ptile_idcpu[ip]}.is_valid())
n_removed++;
else
{
Expand All @@ -141,8 +144,7 @@ namespace impactx
// move down
int const new_index = ip - n_removed;

ptile_src_data.m_aos[new_index] = ptile_src_data.m_aos[ip];

ptile_src_data.m_idcpu[new_index] = ptile_src_data.m_idcpu[ip];
for (int j = 0; j < SrcData::NAR; ++j)
ptile_src_data.m_rdata[j][new_index] = ptile_src_data.m_rdata[j][ip];
for (int j = 0; j < ptile_src_data.m_num_runtime_real; ++j)
Expand Down
Loading

0 comments on commit 9876a9e

Please sign in to comment.