diff --git a/mace_jax/data/utils.py b/mace_jax/data/utils.py
index 7558cf0..c16706c 100644
--- a/mace_jax/data/utils.py
+++ b/mace_jax/data/utils.py
@@ -194,9 +194,7 @@ def load_from_xyz(
 
 class AtomicNumberTable:
     def __init__(self, zs: Sequence[int]):
-        zs = list(zs)
-        # integers
-        assert all(isinstance(z, int) for z in zs)
+        zs = [int(z) for z in zs]
         # unique
         assert len(zs) == len(set(zs))
         # sorted
diff --git a/mace_jax/modules/__init__.py b/mace_jax/modules/__init__.py
index 4f91339..32002b4 100644
--- a/mace_jax/modules/__init__.py
+++ b/mace_jax/modules/__init__.py
@@ -7,7 +7,7 @@
     RadialEmbeddingBlock,
     ScaleShiftBlock,
 )
-from .loss import WeightedEnergyForcesStressLoss
+from .loss import WeightedEnergyForcesStressLoss, uber_loss
 from .message_passing import MessagePassingConvolution
 from .models import MACE
 from .symmetric_contraction import SymmetricContraction
@@ -21,6 +21,7 @@
     "RadialEmbeddingBlock",
     "ScaleShiftBlock",
     "WeightedEnergyForcesStressLoss",
+    "uber_loss",
     "MessagePassingConvolution",
     "MACE",
     "SymmetricContraction",
diff --git a/mace_jax/modules/loss.py b/mace_jax/modules/loss.py
index 54b4357..89bbf48 100644
--- a/mace_jax/modules/loss.py
+++ b/mace_jax/modules/loss.py
@@ -62,3 +62,10 @@ def __repr__(self):
             f"forces_weight={self.forces_weight:.3f}, "
             f"stress_weight={self.stress_weight:.3f})"
         )
+
+
+def uber_loss(x, t=1.0):
+    x_center = jnp.where(jnp.abs(x) <= 1.01 * t, x, 0.0)
+    center = x_center**2 / (2 * t)
+    sides = jnp.abs(x) - t / 2
+    return jnp.where(jnp.abs(x) <= t, center, sides)
diff --git a/mace_jax/tools/train.py b/mace_jax/tools/train.py
index 09cd76c..6a0bd5d 100644
--- a/mace_jax/tools/train.py
+++ b/mace_jax/tools/train.py
@@ -22,6 +22,7 @@ def train(
     optimizer_state: Dict[str, Any],
     steps_per_interval: int,
     ema_decay: Optional[float] = None,
+    progress_bar: bool = True,
 ):
     """
     for interval, params, optimizer_state, ema_params in train(...):
@@ -33,6 +34,7 @@
     logging.info("Started training")
 
     @partial(jax.pmap, in_axes=(None, 0), out_axes=0)
+    # @partial(jax.vmap, in_axes=(None, 0), out_axes=0)
     def grad_fn(params, graph: jraph.GraphsTuple):
         # graph is assumed to be padded by jraph.pad_with_graphs
         mask = jraph.get_graph_padding_mask(graph)  # [n_graphs,]
@@ -85,6 +87,7 @@ def interval_loader():
         interval_loader(),
         desc=f"Train interval {interval}",
         total=steps_per_interval,
+        disable=not progress_bar,
     )
     for graph in p_bar:
         num_updates += 1