Skip to content

Commit

Permalink
Fold trajectory keywise
Browse files Browse the repository at this point in the history
  • Loading branch information
ashenoy463 committed May 6, 2024
1 parent cb70dbc commit 8a6f610
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 24 deletions.
27 changes: 27 additions & 0 deletions mdx/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import h5py
from collections.abc import MutableMapping
import numpy as np
import dask.array as da


# thints
Expand Down Expand Up @@ -41,6 +42,32 @@ def dict_flatten(dictionary, parent_key="", separator="_"):
return dict(items)


def getkeys(x):
    """Return the keys of the mapping ``x`` as a list."""
    return list(x.keys())


def concat_keywise(x: dict, y: dict, keys=("r", "v", "q", "type", "box")) -> dict:
    """
    Binary operation on dicts of Numpy-like arrays, returns dict
    with arrays concatenated along axis 0 keywise.

    Entries of ``x`` and ``y`` whose key is not listed in ``keys`` are
    ignored and do not appear in the result.

    Args:
        x (dict[ArrayLike])
        y (dict[ArrayLike])
        keys (Iterable[str]): Keys to merge. Defaults to a fixed set of
            field names instead of being inferred from the inputs.

    Returns:
        dict[str,ArrayLike]

    Raises:
        KeyError: If any entry of ``keys`` is missing from ``x`` or ``y``.
    """
    # TEMPFIX: not specifying keys makes merging partitions more painful,
    # so the default key set is fixed rather than read from x or y.
    return {k: da.concatenate([x[k], y[k]], axis=0) for k in keys}


def write_rdf(rdf_obj, sim_id: str, chunk_id: str, c: str, s: str, suffix=None):
"""
Store MDAnalysis InterRDF results to HDF5 file
Expand Down
60 changes: 36 additions & 24 deletions mdx/ingest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sparse
import dask.bag as db
import dask.array as da
import dask.dataframe as df
import re
import os
import pandas as pd
Expand All @@ -11,7 +10,7 @@
from mdx.models.meta import FormatMeta
from pydantic import PositiveInt, ValidationError, validate_call
from typing import Union

import mdx.helper_functions as hf

# Exceptions

Expand Down Expand Up @@ -51,6 +50,8 @@ class Simulation:
"""

VALID_TRAJ_FORMATS = ["xarray", "frame"]

def __init__(
self,
meta_file: os.PathLike,
Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(
def read_trajectory(
self,
data_path: os.PathLike = None,
atomic_format: str = "frame",
format: str = "xarray",
blocksize: str = None,
) -> None:
"""
Expand All @@ -100,20 +101,35 @@ def read_trajectory(
.remove(lambda x: x == "ITEM: TIMESTEP")
.map(lambda x: x.split("ITEM: "))
.map(lambda x: x[:-1] if (x[-1] == "TIMESTEP") else x)
.map(self.__process_traj_step, atomic_format=atomic_format)
.map(
.map(self.__process_traj_step)
# .distinct(key=lambda x: x["timestep"]) ; causes memory leak on nanosecond scale data
)
if format == "xarray":
corpus = corpus.map(
lambda x: dict(
# TEMPFIX: Limited columns parsed into arrays for now
step=x["timestep"],
r=da.from_array(x["atomic"][["xu", "yu", "zu"]].values),
v=da.from_array(x["atomic"][["vx", "vy", "vz"]].values),
q=da.from_array(x["atomic"]["q"].values),
type=da.from_array(x["atomic"]["type"].values),
box=x["box"]["bounds"],
r=da.expand_dims(
da.from_array(x["atomic"][["xu", "yu", "zu"]].values), axis=0
),
v=da.expand_dims(
da.from_array(x["atomic"][["vx", "vy", "vz"]].values), axis=0
),
q=da.expand_dims(da.from_array(x["atomic"]["q"].values), axis=0),
type=da.expand_dims(
da.from_array(x["atomic"]["type"].values), axis=0
),
box=da.expand_dims(x["box"]["bounds"], axis=0),
)
).fold(hf.concat_keywise)
elif format == "frame":
# TEMPFIX: distinct causes memory leak on nanosecond scale
# corpus = corpus.distinct(key=lambda x: x["timestep"])
pass
else:
raise InvalidFormat(
"Invalid output format passed. Valid options are {self.VALID_TRAJ_FORMATS}"
)
# .distinct(key=lambda x: x["timestep"]) ; causes memory leak on nanosecond scale data
)

self.trajectory = corpus.compute() if self.eager else corpus

Expand Down Expand Up @@ -208,9 +224,9 @@ def read_ave(self):

# Intermediate processing steps

def __process_traj_step(self, step_text: str, atomic_format: str):
def __process_traj_step(self, step_text: str):
"""
Parse raw trajectory data text of one frame into chosen format
Parse raw trajectory data text of one timestep into a frame
"""
frame = {"timestep": "", "n_atoms": "", "atomic": ""}
item_regex = "([A-Z ]*)([A-z ]*)\n((.*[\n]?)*)"
Expand Down Expand Up @@ -244,16 +260,12 @@ def __process_traj_step(self, step_text: str, atomic_format: str):

elif label == "ATOMS":

if atomic_format == "frame":
atomic_data = (
pd.read_csv(io.StringIO(data), sep=" ", names=header.split())
.set_index("id")
.sort_index()
)
frame["atomic"] = atomic_data

else:
raise InvalidFormat("Select a valid atomic output format")
atomic_data = (
pd.read_csv(io.StringIO(data), sep=" ", names=header.split())
.set_index("id")
.sort_index()
)
frame["atomic"] = atomic_data

elif label == "DIMENSIONS":
# ??Grid??
Expand Down

0 comments on commit 8a6f610

Please sign in to comment.