Skip to content

Commit

Permalink
uber loss
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jul 8, 2023
1 parent 042a21f commit be82f6b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
4 changes: 1 addition & 3 deletions mace_jax/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ def load_from_xyz(

class AtomicNumberTable:
def __init__(self, zs: Sequence[int]):
zs = list(zs)
# integers
assert all(isinstance(z, int) for z in zs)
zs = [int(z) for z in zs]
# unique
assert len(zs) == len(set(zs))
# sorted
Expand Down
3 changes: 2 additions & 1 deletion mace_jax/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
RadialEmbeddingBlock,
ScaleShiftBlock,
)
from .loss import WeightedEnergyForcesStressLoss
from .loss import WeightedEnergyForcesStressLoss, uber_loss
from .message_passing import MessagePassingConvolution
from .models import MACE
from .symmetric_contraction import SymmetricContraction
Expand All @@ -21,6 +21,7 @@
"RadialEmbeddingBlock",
"ScaleShiftBlock",
"WeightedEnergyForcesStressLoss",
"uber_loss",
"MessagePassingConvolution",
"MACE",
"SymmetricContraction",
Expand Down
7 changes: 7 additions & 0 deletions mace_jax/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,10 @@ def __repr__(self):
f"forces_weight={self.forces_weight:.3f}, "
f"stress_weight={self.stress_weight:.3f})"
)


def uber_loss(x, t=1.0):
x_center = jnp.where(jnp.abs(x) <= 1.01 * t, x, 0.0)
center = x_center**2 / (2 * t)
sides = jnp.abs(x) - t / 2
return jnp.where(jnp.abs(x) <= t, center, sides)
3 changes: 3 additions & 0 deletions mace_jax/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def train(
optimizer_state: Dict[str, Any],
steps_per_interval: int,
ema_decay: Optional[float] = None,
progress_bar: bool = True,
):
"""
for interval, params, optimizer_state, ema_params in train(...):
Expand All @@ -33,6 +34,7 @@ def train(
logging.info("Started training")

@partial(jax.pmap, in_axes=(None, 0), out_axes=0)
# @partial(jax.vmap, in_axes=(None, 0), out_axes=0)
def grad_fn(params, graph: jraph.GraphsTuple):
# graph is assumed to be padded by jraph.pad_with_graphs
mask = jraph.get_graph_padding_mask(graph) # [n_graphs,]
Expand Down Expand Up @@ -85,6 +87,7 @@ def interval_loader():
interval_loader(),
desc=f"Train interval {interval}",
total=steps_per_interval,
disable=not progress_bar,
)
for graph in p_bar:
num_updates += 1
Expand Down

0 comments on commit be82f6b

Please sign in to comment.