From 42bd6b4b94501b2b03ccab476a87fc489f5a4b87 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 12 Jul 2023 15:44:54 -0400 Subject: [PATCH] add ParallelLoader --- mace_jax/data/__init__.py | 2 ++ mace_jax/data/utils.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/mace_jax/data/__init__.py b/mace_jax/data/__init__.py index c09ebbb..01a88b3 100644 --- a/mace_jax/data/__init__.py +++ b/mace_jax/data/__init__.py @@ -13,6 +13,7 @@ GraphEdges, GraphGlobals, GraphDataLoader, + ParallelLoader, AtomicNumberTable, atomic_numbers_to_indices, get_atomic_number_table_from_zs, @@ -34,6 +35,7 @@ "GraphEdges", "GraphGlobals", "GraphDataLoader", + "ParallelLoader", "AtomicNumberTable", "atomic_numbers_to_indices", "get_atomic_number_table_from_zs", diff --git a/mace_jax/data/utils.py b/mace_jax/data/utils.py index c16706c..4750155 100644 --- a/mace_jax/data/utils.py +++ b/mace_jax/data/utils.py @@ -424,6 +424,35 @@ def subset(self, i): n_mantissa_bits=self.n_mantissa_bits, ) + def replace_graphs(self, graphs: List[jraph.GraphsTuple]): + return GraphDataLoader( + graphs=graphs, + n_node=self.n_node, + n_edge=self.n_edge, + n_graph=self.n_graph, + min_n_node=self.min_n_node, + min_n_edge=self.min_n_edge, + min_n_graph=self.min_n_graph, + shuffle=self.shuffle, + n_mantissa_bits=self.n_mantissa_bits, + ) + + +class ParallelLoader: + def __init__(self, loader, n: int): + self.loader = loader + self.n = n + + def __iter__(self): + it = iter(self.loader) + while True: + try: + yield jax.tree_map( + lambda *x: jnp.stack(x), *[next(it) for _ in range(self.n)] + ) + except StopIteration: + return + def pad_graph_to_nearest_ceil_mantissa( graphs_tuple: jraph.GraphsTuple,