Skip to content

Commit

Permalink
add ParallelLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jul 12, 2023
1 parent be82f6b commit 42bd6b4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mace_jax/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
GraphEdges,
GraphGlobals,
GraphDataLoader,
ParallelLoader,
AtomicNumberTable,
atomic_numbers_to_indices,
get_atomic_number_table_from_zs,
Expand All @@ -34,6 +35,7 @@
"GraphEdges",
"GraphGlobals",
"GraphDataLoader",
"ParallelLoader",
"AtomicNumberTable",
"atomic_numbers_to_indices",
"get_atomic_number_table_from_zs",
Expand Down
29 changes: 29 additions & 0 deletions mace_jax/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,35 @@ def subset(self, i):
n_mantissa_bits=self.n_mantissa_bits,
)

def replace_graphs(self, graphs: List[jraph.GraphsTuple]):
return GraphDataLoader(
graphs=graphs,
n_node=self.n_node,
n_edge=self.n_edge,
n_graph=self.n_graph,
min_n_node=self.min_n_node,
min_n_edge=self.min_n_edge,
min_n_graph=self.min_n_graph,
shuffle=self.shuffle,
n_mantissa_bits=self.n_mantissa_bits,
)


class ParallelLoader:
def __init__(self, loader, n: int):
self.loader = loader
self.n = n

def __iter__(self):
it = iter(self.loader)
while True:
try:
yield jax.tree_map(
lambda *x: jnp.stack(x), *[next(it) for _ in range(self.n)]
)
except StopIteration:
return


def pad_graph_to_nearest_ceil_mantissa(
graphs_tuple: jraph.GraphsTuple,
Expand Down

0 comments on commit 42bd6b4

Please sign in to comment.