From ff078d4a126af3a8728b9a3afd790056521f090e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sun, 29 Jan 2023 21:11:14 -0500 Subject: [PATCH 1/7] No epochs --- configs/aspirin.gin | 12 +- configs/aspirin_small.gin | 14 +- mace_jax/plot_train.py | 29 ++-- mace_jax/run_train.py | 5 +- mace_jax/tools/gin_functions.py | 233 ++++++++++++++++---------------- mace_jax/tools/gin_model.py | 4 +- mace_jax/tools/train.py | 43 +++--- 7 files changed, 164 insertions(+), 176 deletions(-) diff --git a/configs/aspirin.gin b/configs/aspirin.gin index a1edfbe..137e224 100644 --- a/configs/aspirin.gin +++ b/configs/aspirin.gin @@ -20,9 +20,7 @@ datasets.n_mantissa_bits = 3 model.radial_basis = @bessel_basis bessel_basis.number = 8 -# model.radial_envelope = @u_envelope # 10. @ epoch 350 -# u_envelope.p = 5 -model.radial_envelope = @soft_envelope # 9.5 @ epoch 350 +model.radial_envelope = @soft_envelope model.symmetric_tensor_product_basis = True # symmetric is slightly worse but slightly faster model.off_diagonal = False @@ -52,15 +50,15 @@ loss.stress_weight = 0.0 optimizer.algorithm = @amsgrad # @adam is slightly worse optimizer.lr = 0.01 -optimizer.max_num_epochs = 100 +optimizer.steps_per_interval = 1024 +optimizer.max_num_intervals = 100 # optimizer.weight_decay = 5e-7 # optimizer.scheduler = @piecewise_constant_schedule # piecewise_constant_schedule.boundaries_and_scales = { -# 100: 0.1, # divide the learning rate by 10 after 100 epochs -# 1000: 0.1, # divide the learning rate by 10 after 1000 epochs +# 100: 0.1, # divide the learning rate by 10 after 100 intervals +# 1000: 0.1, # divide the learning rate by 10 after 1000 intervals # } -train.eval_interval = 5 train.patience = 2048 train.ema_decay = 0.99 train.eval_train = False # if True, evaluates the whole training set at each eval_interval diff --git a/configs/aspirin_small.gin b/configs/aspirin_small.gin index 4fc871d..56092f0 100644 --- a/configs/aspirin_small.gin +++ b/configs/aspirin_small.gin @@ -21,9 +21,7 @@ datasets.n_mantissa_bits = 3 model.radial_basis = @bessel_basis bessel_basis.number = 8 -# model.radial_envelope = @u_envelope # 10. 
@ epoch 350 -# u_envelope.p = 5 -model.radial_envelope = @soft_envelope # 9.5 @ epoch 350 +model.radial_envelope = @soft_envelope model.symmetric_tensor_product_basis = True # symmetric is slightly worse but slightly faster model.off_diagonal = False @@ -53,17 +51,17 @@ loss.stress_weight = 0.0 optimizer.algorithm = @sgd optimizer.lr = 0.01 -optimizer.max_num_epochs = 2 +optimizer.steps_per_interval = 100 +optimizer.max_num_intervals = 2 # optimizer.weight_decay = 5e-7 # optimizer.scheduler = @piecewise_constant_schedule # piecewise_constant_schedule.boundaries_and_scales = { -# 100: 0.1, # divide the learning rate by 10 after 100 epochs -# 1000: 0.1, # divide the learning rate by 10 after 1000 epochs +# 100: 0.1, # divide the learning rate by 10 after 100 intervals +# 1000: 0.1, # divide the learning rate by 10 after 1000 intervals # } -train.eval_interval = 5 train.patience = 2048 train.ema_decay = 0.99 -train.eval_train = False # if True, evaluates the whole training set at each eval_interval +train.eval_train = True # if True, evaluates the whole training set at each eval_interval train.eval_test = False train.log_errors = "PerAtomMAE" diff --git a/mace_jax/plot_train.py b/mace_jax/plot_train.py index afc67bd..2046fae 100644 --- a/mace_jax/plot_train.py +++ b/mace_jax/plot_train.py @@ -44,17 +44,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--path", help="path to results file or directory", required=True ) - parser.add_argument( - "--min_epoch", help="minimum epoch", default=50, type=int, required=False - ) return parser.parse_args() -def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None: - data = data[data["epoch"] > min_epoch] - +def plot(data: pd.DataFrame, output_path: str) -> None: data = ( - data.groupby(["path", "name", "mode", "epoch"]) + data.groupby(["path", "name", "mode", "interval"]) .agg([np.mean, np.std]) .reset_index() ) @@ -68,14 +63,14 @@ def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None: ax = axes[0] ax.plot( - valid_data["epoch"], + valid_data["interval"], valid_data["loss"]["mean"], color=colors[0], zorder=1, label="Validation", ) # ax.fill_between( - # x=valid_data["epoch"], + # x=valid_data["interval"], # y1=valid_data["loss"]["mean"] - valid_data["loss"]["std"], # y2=valid_data["loss"]["mean"] + valid_data["loss"]["std"], # alpha=0.5, @@ -83,14 +78,14 @@ def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None: # color=colors[0], # ) ax.plot( - train_data["epoch"], + train_data["interval"], train_data["loss"]["mean"], color=colors[3], zorder=1, label="Training", ) # ax.fill_between( - # x=train_data["epoch"], + # x=train_data["interval"], # y1=train_data["loss"]["mean"] - train_data["loss"]["std"], # y2=train_data["loss"]["mean"] + train_data["loss"]["std"], # alpha=0.5, @@ -100,20 +95,20 @@ def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None: ax.set_xscale("log") ax.set_yscale("log") - ax.set_xlabel("Epoch") + ax.set_xlabel("Interval") ax.set_ylabel("Loss") ax.legend() ax = axes[1] ax.plot( - valid_data["epoch"], + valid_data["interval"], valid_data["mae_e"]["mean"], color=colors[1], zorder=1, label="MAE Energy [eV]", ) # ax.fill_between( - # x=valid_data["epoch"], + # x=valid_data["interval"], # y1=valid_data["mae_e"]["mean"] - valid_data["mae_e"]["std"], # y2=valid_data["mae_e"]["mean"] + valid_data["mae_e"]["std"], # alpha=0.5, @@ -121,14 +116,14 @@ def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None: # color=colors[1], # ) ax.plot( - 
valid_data["epoch"], + valid_data["interval"], valid_data["mae_f"]["mean"], color=colors[2], zorder=1, label="MAE Forces [eV/Å]", ) # ax.fill_between( - # x=valid_data["epoch"], + # x=valid_data["interval"], # y1=valid_data["mae_f"]["mean"] - valid_data["mae_f"]["std"], # y2=valid_data["mae_f"]["mean"] + valid_data["mae_f"]["std"], # alpha=0.5, @@ -166,7 +161,7 @@ def main(): ) for (path, name), group in data.groupby(["path", "name"]): - plot(group, min_epoch=args.min_epoch, output_path=f"{path}/{name}.pdf") + plot(group, output_path=f"{path}/{name}.pdf") if __name__ == "__main__": diff --git a/mace_jax/run_train.py b/mace_jax/run_train.py index 7ed7142..19d8091 100644 --- a/mace_jax/run_train.py +++ b/mace_jax/run_train.py @@ -43,7 +43,7 @@ def main(): if checks(predictor, params, train_loader): return - gradient_transform, max_num_epochs = optimizer(train_loader.approx_length()) + gradient_transform, steps_per_interval, max_num_intervals = optimizer() optimizer_state = gradient_transform.init(params) logging.info(f"Number of parameters: {tools.count_parameters(params)}") @@ -59,7 +59,8 @@ def main(): valid_loader, test_loader, gradient_transform, - max_num_epochs, + max_num_intervals, + steps_per_interval, logger, directory, tag, diff --git a/mace_jax/tools/gin_functions.py b/mace_jax/tools/gin_functions.py index 5745bb4..58742d0 100644 --- a/mace_jax/tools/gin_functions.py +++ b/mace_jax/tools/gin_functions.py @@ -110,7 +110,7 @@ def checks( @gin.configurable def exponential_decay( lr: float, - steps_per_epoch: int, + steps_per_interval: int, *, transition_steps: float = 0.0, decay_rate: float = 0.5, @@ -120,9 +120,9 @@ def exponential_decay( ): return optax.exponential_decay( init_value=lr, - transition_steps=transition_steps * steps_per_epoch, + transition_steps=transition_steps * steps_per_interval, decay_rate=decay_rate, - transition_begin=transition_begin * steps_per_epoch, + transition_begin=transition_begin * steps_per_interval, staircase=staircase, end_value=end_value, ) @@ -130,10 +130,10 @@ def exponential_decay( @gin.configurable def piecewise_constant_schedule( - lr: float, steps_per_epoch: int, *, boundaries_and_scales: Dict[float, float] + lr: float, steps_per_interval: int, *, boundaries_and_scales: Dict[float, float] ): boundaries_and_scales = { - boundary * steps_per_epoch: scale + boundary * steps_per_interval: scale for boundary, scale in boundaries_and_scales.items() } return optax.piecewise_constant_schedule( @@ -142,7 +142,7 @@ def piecewise_constant_schedule( @gin.register -def constant_schedule(lr, steps_per_epoch): +def constant_schedule(lr, steps_per_interval): return optax.constant_schedule(lr) @@ -153,10 +153,10 @@ def constant_schedule(lr, steps_per_epoch): @gin.configurable def optimizer( - steps_per_epoch: int, + steps_per_interval: int, + max_num_intervals: int, weight_decay=0.0, lr=0.01, - max_num_epochs: int = 2048, algorithm: Callable = optax.scale_by_adam, scheduler: Callable = constant_schedule, ): @@ -174,10 +174,11 @@ def weight_decay_mask(params): optax.chain( optax.add_decayed_weights(weight_decay, mask=weight_decay_mask), algorithm(), - optax.scale_by_schedule(scheduler(lr, steps_per_epoch)), + optax.scale_by_schedule(scheduler(lr, steps_per_interval)), optax.scale(-1.0), # Gradient descent. 
), - max_num_epochs, + steps_per_interval, + max_num_intervals, ) @@ -190,7 +191,8 @@ def train( valid_loader, test_loader, gradient_transform, - max_num_epochs, + max_num_intervals: int, + steps_per_interval: int, logger, directory, tag, @@ -198,7 +200,6 @@ def train( patience: Optional[int] = None, eval_train: bool = False, eval_test: bool = False, - eval_interval: int = 1, log_errors: str = "PerAtomRMSE", **kwargs, ): @@ -206,21 +207,20 @@ def train( patience_counter = 0 loss_fn = loss() start_time = time.perf_counter() - total_time_per_epoch = [] - eval_time_per_epoch = [] + total_time_per_interval = [] + eval_time_per_interval = [] - for epoch, params, optimizer_state, ema_params in tools.train( + for interval, params, optimizer_state, ema_params in tools.train( model=model, params=params, loss_fn=loss_fn, train_loader=train_loader, gradient_transform=gradient_transform, optimizer_state=optimizer_state, - start_epoch=0, - logger=logger, + steps_per_interval=steps_per_interval, **kwargs, ): - total_time_per_epoch += [time.perf_counter() - start_time] + total_time_per_interval += [time.perf_counter() - start_time] start_time = time.perf_counter() try: @@ -230,114 +230,113 @@ def train( else: profile_nn_jax.restart_timer() - last_epoch = epoch == max_num_epochs - if epoch % eval_interval == 0 or last_epoch: - with open(f"{directory}/{tag}.pkl", "wb") as f: - pickle.dump(gin.operative_config_str(), f) - pickle.dump(params, f) - - def eval_and_print(loader, mode: str): - loss_, metrics_ = tools.evaluate( - model=model, - params=ema_params, - loss_fn=loss_fn, - data_loader=loader, - ) - metrics_["mode"] = mode - metrics_["epoch"] = epoch - logger.log(metrics_) - - if log_errors == "PerAtomRMSE": - error_e = "rmse_e_per_atom" - error_f = "rmse_f" - error_s = "rmse_s" - elif log_errors == "rel_PerAtomRMSE": - error_e = "rmse_e_per_atom" - error_f = "rel_rmse_f" - error_s = "rel_rmse_s" - elif log_errors == "TotalRMSE": - error_e = "rmse_e" - error_f = "rmse_f" - error_s = "rmse_s" - elif log_errors == "PerAtomMAE": - error_e = "mae_e_per_atom" - error_f = "mae_f" - error_s = "mae_s" - elif log_errors == "rel_PerAtomMAE": - error_e = "mae_e_per_atom" - error_f = "rel_mae_f" - error_s = "rel_mae_s" - elif log_errors == "TotalMAE": - error_e = "mae_e" - error_f = "mae_f" - error_s = "mae_s" - - def _(x: str): - v: float = metrics_.get(x, None) - if v is None: - return "N/A" - if x.startswith("rel_"): - return f"{100 * v:.1f}%" - if "_e" in x: - return f"{1e3 * v:.1f} meV" - if "_f" in x: - return f"{1e3 * v:.1f} meV/Å" - if "_s" in x: - return f"{1e3 * v:.1f} meV/ų" - raise NotImplementedError - - logging.info( - f"Epoch {epoch}: {mode}: " - f"loss={loss_:.4f}, " - f"{error_e}={_(error_e)}, " - f"{error_f}={_(error_f)}, " - f"{error_s}={_(error_s)}" - ) - return loss_ + last_interval = interval == max_num_intervals - if eval_train or last_epoch: - if isinstance(eval_train, (int, float)): - eval_and_print(train_loader.subset(eval_train), "eval_train") - else: - eval_and_print(train_loader, "eval_train") - - if ( - (eval_test or last_epoch) - and test_loader is not None - and len(test_loader) > 0 - ): - eval_and_print(test_loader, "eval_test") - - if valid_loader is not None and len(valid_loader) > 0: - loss_ = eval_and_print(valid_loader, "eval_valid") - - if loss_ >= lowest_loss: - patience_counter += 1 - if patience is not None and patience_counter >= patience: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" - ) - break - else: - lowest_loss = 
loss_ - patience_counter = 0 + with open(f"{directory}/{tag}.pkl", "wb") as f: + pickle.dump(gin.operative_config_str(), f) + pickle.dump(params, f) - eval_time_per_epoch += [time.perf_counter() - start_time] - avg_time_per_epoch = np.mean(total_time_per_epoch[-eval_interval:]) - avg_eval_time_per_epoch = np.mean(eval_time_per_epoch[-eval_interval:]) + def eval_and_print(loader, mode: str): + loss_, metrics_ = tools.evaluate( + model=model, + params=ema_params, + loss_fn=loss_fn, + data_loader=loader, + name=mode, + ) + metrics_["mode"] = mode + metrics_["interval"] = interval + logger.log(metrics_) + + if log_errors == "PerAtomRMSE": + error_e = "rmse_e_per_atom" + error_f = "rmse_f" + error_s = "rmse_s" + elif log_errors == "rel_PerAtomRMSE": + error_e = "rmse_e_per_atom" + error_f = "rel_rmse_f" + error_s = "rel_rmse_s" + elif log_errors == "TotalRMSE": + error_e = "rmse_e" + error_f = "rmse_f" + error_s = "rmse_s" + elif log_errors == "PerAtomMAE": + error_e = "mae_e_per_atom" + error_f = "mae_f" + error_s = "mae_s" + elif log_errors == "rel_PerAtomMAE": + error_e = "mae_e_per_atom" + error_f = "rel_mae_f" + error_s = "rel_mae_s" + elif log_errors == "TotalMAE": + error_e = "mae_e" + error_f = "mae_f" + error_s = "mae_s" + + def _(x: str): + v: float = metrics_.get(x, None) + if v is None: + return "N/A" + if x.startswith("rel_"): + return f"{100 * v:.1f}%" + if "_e" in x: + return f"{1e3 * v:.1f} meV" + if "_f" in x: + return f"{1e3 * v:.1f} meV/Å" + if "_s" in x: + return f"{1e3 * v:.1f} meV/ų" + raise NotImplementedError logging.info( - f"Epoch {epoch}: Time per epoch: {avg_time_per_epoch:.1f}s, " - f"among which {avg_eval_time_per_epoch:.1f}s for evaluation." + f"Interval {interval}: {mode}: " + f"loss={loss_:.4f}, " + f"{error_e}={_(error_e)}, " + f"{error_f}={_(error_f)}, " + f"{error_s}={_(error_s)}" ) - else: - eval_time_per_epoch += [time.perf_counter() - start_time] # basically 0 + return loss_ + + if eval_train or last_interval: + if isinstance(eval_train, (int, float)): + eval_and_print(train_loader.subset(eval_train), "eval_train") + else: + eval_and_print(train_loader, "eval_train") + + if ( + (eval_test or last_interval) + and test_loader is not None + and len(test_loader) > 0 + ): + eval_and_print(test_loader, "eval_test") + + if valid_loader is not None and len(valid_loader) > 0: + loss_ = eval_and_print(valid_loader, "eval_valid") + + if loss_ >= lowest_loss: + patience_counter += 1 + if patience is not None and patience_counter >= patience: + logging.info( + f"Stopping optimization after {patience_counter} intervals without improvement" + ) + break + else: + lowest_loss = loss_ + patience_counter = 0 + + eval_time_per_interval += [time.perf_counter() - start_time] + avg_time_per_interval = np.mean(total_time_per_interval[-3:]) + avg_eval_time_per_interval = np.mean(eval_time_per_interval[-3:]) + + logging.info( + f"Interval {interval}: Time per interval: {avg_time_per_interval:.1f}s, " + f"among which {avg_eval_time_per_interval:.1f}s for evaluation." 
+ ) - if last_epoch: + if last_interval: break logging.info("Training complete") - return epoch, ema_params + return interval, ema_params def parse_argv(argv: List[str]): diff --git a/mace_jax/tools/gin_model.py b/mace_jax/tools/gin_model.py index dcc7436..2629062 100644 --- a/mace_jax/tools/gin_model.py +++ b/mace_jax/tools/gin_model.py @@ -67,7 +67,7 @@ def __init__(self, num_species: int, irreps_out: e3nn.Irreps): def __call__(self, node_specie: jnp.ndarray) -> e3nn.IrrepsArray: w = hk.get_parameter( - f"embeddings", + "embeddings", shape=(self.num_species, self.irreps_out.dim), dtype=jnp.float32, init=hk.initializers.RandomNormal(), @@ -118,7 +118,7 @@ def model( avg_r_min = tools.compute_avg_min_neighbor_distance(train_graphs) logging.info(f"Compute the average min neighbor distance: {avg_r_min:.3f}") elif avg_r_min is None: - logging.info(f"Do not normalize the radial basis (avg_r_min=None)") + logging.info("Do not normalize the radial basis (avg_r_min=None)") else: logging.info(f"Use the average min neighbor distance: {avg_r_min:.3f}") diff --git a/mace_jax/tools/train.py b/mace_jax/tools/train.py index b89c9ca..2aaf7b4 100644 --- a/mace_jax/tools/train.py +++ b/mace_jax/tools/train.py @@ -20,8 +20,7 @@ def train( train_loader: data.GraphDataLoader, gradient_transform: Any, optimizer_state: Dict[str, Any], - start_epoch: int, - logger: Any, + steps_per_interval: int, ema_decay: Optional[float] = None, ): num_updates = 0 @@ -53,14 +52,23 @@ def update_fn( last_cache_size = update_fn._cache_size() - for epoch in itertools.count(start_epoch): - yield epoch, params, optimizer_state, ema_params + def interval_loader(): + i = 0 + while True: + for graph in train_loader: + yield graph + i += 1 + if i >= steps_per_interval: + return - # Train one epoch + for interval in itertools.count(): + yield interval, params, optimizer_state, ema_params + + # Train one interval p_bar = tqdm.tqdm( - train_loader, - desc=f"Epoch {epoch}", - total=train_loader.approx_length(), + interval_loader(), + desc=f"Train interval {interval}", + total=steps_per_interval, ) for graph in p_bar: num_updates += 1 @@ -70,18 +78,6 @@ def update_fn( ) loss = float(loss) p_bar.set_postfix({"loss": f"{loss:7.3f}"}) - opt_metrics = { - "loss": loss, - "time": time.time() - start_time, - } - - opt_metrics["mode"] = "opt" - opt_metrics["epoch"] = epoch - opt_metrics["num_updates"] = num_updates - opt_metrics["epoch_"] = ( - start_epoch + num_updates / train_loader.approx_length() - ) - logger.log(opt_metrics) if last_cache_size != update_fn._cache_size(): last_cache_size = update_fn._cache_size() @@ -91,7 +87,7 @@ def update_fn( logging.info(f"- n_edge={graph.n_edge} total={graph.n_edge.sum()}") logging.info(f"Outout: loss= {loss:.3f}") logging.info( - f"Compilation time: {opt_metrics['time']:.3f}s, cache size: {last_cache_size}" + f"Compilation time: {time.time() - start_time:.3f}s, cache size: {last_cache_size}" ) @@ -100,6 +96,7 @@ def evaluate( params: Any, loss_fn: Any, data_loader: data.GraphDataLoader, + name: str = "Evaluation", ) -> Tuple[float, Dict[str, Any]]: total_loss = 0.0 num_graphs = 0 @@ -122,7 +119,7 @@ def evaluate( last_cache_size = None start_time = time.time() - p_bar = tqdm.tqdm(data_loader, desc="Evaluating", total=data_loader.approx_length()) + p_bar = tqdm.tqdm(data_loader, desc=name, total=data_loader.approx_length()) for ref_graph in p_bar: output = model(params, ref_graph) pred_graph = ref_graph._replace( @@ -178,7 +175,7 @@ def evaluate( stress_list.append(ref_graph.globals.stress) if num_graphs 
== 0: - logging.warning("No graphs in data_loader !") + logging.warning(f"No graphs in data_loader ! Returning 0.0 for {name}") return 0.0, {} avg_loss = total_loss / num_graphs From 9c14ed5c1a75c22686210cb2aa9add5e357e5daf Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sun, 29 Jan 2023 21:31:12 -0500 Subject: [PATCH 2/7] jit predictor --- mace_jax/run_train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mace_jax/run_train.py b/mace_jax/run_train.py index 19d8091..f670c9d 100644 --- a/mace_jax/run_train.py +++ b/mace_jax/run_train.py @@ -2,6 +2,7 @@ import sys import gin +import jax import mace_jax from mace_jax import tools @@ -36,8 +37,8 @@ def main(): params = reload(params) - predictor = lambda w, g: tools.predict_energy_forces_stress( - lambda *x: model_fn(w, *x), g + predictor = jax.jit( + lambda w, g: tools.predict_energy_forces_stress(lambda *x: model_fn(w, *x), g) ) if checks(predictor, params, train_loader): From 32cde594dab164edebf77f1b2eb64465f7cd71ef Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 2 Feb 2023 18:49:46 -0500 Subject: [PATCH 3/7] test_num --- mace_jax/tools/gin_datasets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mace_jax/tools/gin_datasets.py b/mace_jax/tools/gin_datasets.py index 08655d2..28b623e 100644 --- a/mace_jax/tools/gin_datasets.py +++ b/mace_jax/tools/gin_datasets.py @@ -19,6 +19,7 @@ def datasets( valid_fraction: float = None, valid_num: int = None, test_path: str = None, + test_num: int = None, seed: int = 1234, energy_key: str = "energy", forces_key: str = "forces", @@ -93,6 +94,7 @@ def datasets( energy_key=energy_key, forces_key=forces_key, extract_atomic_energies=False, + num_configs=test_num, prefactor_stress=prefactor_stress, remap_stress=remap_stress, ) From 0072f1857daa9c108bb2f5cfaac524bbdfad4acf Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 2 Feb 2023 19:10:42 -0500 Subject: [PATCH 4/7] safety --- mace_jax/run_train.py | 5 ++++- mace_jax/tools/gin_model.py | 13 +++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/mace_jax/run_train.py b/mace_jax/run_train.py index f670c9d..61c54d7 100644 --- a/mace_jax/run_train.py +++ b/mace_jax/run_train.py @@ -32,7 +32,10 @@ def main(): train_loader, valid_loader, test_loader, atomic_energies_dict, r_max = datasets() model_fn, params, num_message_passing = model( - r_max, atomic_energies_dict, train_loader.graphs, initialize_seed=seed + r_max=r_max, + atomic_energies_dict=atomic_energies_dict, + train_graphs=train_loader.graphs, + initialize_seed=seed, ) params = reload(params) diff --git a/mace_jax/tools/gin_model.py b/mace_jax/tools/gin_model.py index 2629062..e701b4d 100644 --- a/mace_jax/tools/gin_model.py +++ b/mace_jax/tools/gin_model.py @@ -80,11 +80,11 @@ def __call__(self, node_specie: jnp.ndarray) -> e3nn.IrrepsArray: @gin.configurable def model( + *, r_max: float, atomic_energies_dict: Dict[int, float] = None, train_graphs: List[jraph.GraphsTuple] = None, initialize_seed: Optional[int] = None, - *, scaling: Callable = None, atomic_energies: Union[str, np.ndarray, Dict[int, float]] = None, avg_num_neighbors: float = "average", @@ -166,11 +166,12 @@ def model( # check that num_species is consistent with the dataset if z_table is None: - for graph in train_graphs: - if not np.all(graph.nodes.species < num_species): - raise ValueError( - f"max(graph.nodes.species)={np.max(graph.nodes.species)} >= num_species={num_species}" - ) + if train_graphs is not None: + for graph in train_graphs: + if not 
np.all(graph.nodes.species < num_species): + raise ValueError( + f"max(graph.nodes.species)={np.max(graph.nodes.species)} >= num_species={num_species}" + ) else: if max(z_table.zs) >= num_species: raise ValueError( From a8f0cab46cf252e9c67039e657c8394dc95129ce Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 2 Feb 2023 19:32:27 -0500 Subject: [PATCH 5/7] fix --- mace_jax/tools/gin_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace_jax/tools/gin_datasets.py b/mace_jax/tools/gin_datasets.py index 28b623e..c2db5cb 100644 --- a/mace_jax/tools/gin_datasets.py +++ b/mace_jax/tools/gin_datasets.py @@ -42,7 +42,7 @@ def datasets( """Load training and test dataset from xyz file""" atomic_energies_dict, all_train_configs = data.load_from_xyz( - file_path=train_path, + file_or_path=train_path, config_type_weights=config_type_weights, energy_key=energy_key, forces_key=forces_key, From f888c3bfbadedd45bfa3ea02815e9cb30dc8452a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 2 Feb 2023 19:33:54 -0500 Subject: [PATCH 6/7] fix --- mace_jax/tools/gin_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace_jax/tools/gin_datasets.py b/mace_jax/tools/gin_datasets.py index c2db5cb..9e89084 100644 --- a/mace_jax/tools/gin_datasets.py +++ b/mace_jax/tools/gin_datasets.py @@ -57,7 +57,7 @@ def datasets( if valid_path is not None: _, valid_configs = data.load_from_xyz( - file_path=valid_path, + file_or_path=valid_path, config_type_weights=config_type_weights, energy_key=energy_key, forces_key=forces_key, @@ -89,7 +89,7 @@ def datasets( if test_path is not None: _, test_configs = data.load_from_xyz( - file_path=test_path, + file_or_path=test_path, config_type_weights=config_type_weights, energy_key=energy_key, forces_key=forces_key, From 989707eea7fabafd2d007c2d3f636158c6d97fb9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 13 Feb 2023 13:57:47 -0500 Subject: [PATCH 7/7] Update README.md --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index f32edbf..e4c56ff 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,14 @@ nox ## Installation +From github: + +```sh +pip install git+https://github.com/ACEsuit/mace-jax +``` + +Or locally: + ```sh python setup.py develop ```
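Taken together, the patches above replace epoch-based bookkeeping with fixed-size intervals of optimizer steps: `optimizer()` now returns `(gradient_transform, steps_per_interval, max_num_intervals)`, and checkpointing, evaluation, and patience are applied once per interval, so `train.eval_interval` disappears from the configs. Below is a minimal sketch of that contract, assuming only what the diffs show; `update_fn` and `evaluate` are hypothetical stand-ins for the jitted update step and `tools.evaluate`:

```python
def interval_loader(train_loader, steps_per_interval):
    """Yield exactly steps_per_interval batches, re-iterating the loader as needed.

    Mirrors the generator added to mace_jax/tools/train.py: progress is counted
    in optimizer steps, so evaluation frequency no longer depends on dataset size.
    """
    i = 0
    while True:
        for batch in train_loader:  # restart the loader once it is exhausted
            yield batch
            i += 1
            if i >= steps_per_interval:
                return


def train(params, update_fn, evaluate, train_loader,
          steps_per_interval, max_num_intervals):
    for interval in range(max_num_intervals):
        for batch in interval_loader(train_loader, steps_per_interval):
            params = update_fn(params, batch)  # one optimizer step
        evaluate(params, interval)  # once per interval, not per eval_interval
    return params
```

With the `aspirin.gin` values above, one interval is 1024 optimizer steps, so `max_num_intervals = 100` corresponds to 102,400 updates regardless of how many graphs are in the training set.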
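Patch 2 wraps the predictor in `jax.jit` so that energies, forces, and stress come out of a single compiled function of `(params, graph)`. A toy analogue of why jitting the combined function pays off, assuming only that forces are the negative gradient of the energy with respect to positions (the quadratic `energy_fn` is an illustrative stand-in, not the MACE model):

```python
import jax
import jax.numpy as jnp


def energy_fn(w, positions):
    # Toy stand-in for model_fn: a quadratic well per coordinate.
    return jnp.sum(w * positions**2)


# Jitting over both weights and positions compiles the energy/force evaluation
# once per input shape and reuses the compiled version on every later call.
predictor = jax.jit(
    lambda w, x: (
        energy_fn(w, x),
        -jax.grad(energy_fn, argnums=1)(w, x),  # forces = -dE/dpositions
    )
)

w = jnp.ones(3)
x = jnp.array([0.1, -0.2, 0.3])
energy, forces = predictor(w, x)
```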
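Because the scheduler wrappers in `gin_functions.py` now take `steps_per_interval` instead of `steps_per_epoch`, boundaries written in the gin configs are expressed in intervals and converted to optimizer steps internally. A sketch of that conversion for the piecewise-constant case, with boundary values mirroring the commented-out example in the configs:

```python
import optax


def piecewise_constant_schedule(lr, steps_per_interval, boundaries_and_scales):
    # Boundaries arrive in units of intervals; optax expects optimizer steps.
    boundaries_and_scales = {
        boundary * steps_per_interval: scale
        for boundary, scale in boundaries_and_scales.items()
    }
    return optax.piecewise_constant_schedule(lr, boundaries_and_scales)


# Divide the learning rate by 10 after 100 intervals and again after 1000;
# with steps_per_interval=1024 the first drop lands at optimizer step 102_400.
schedule = piecewise_constant_schedule(
    lr=0.01, steps_per_interval=1024, boundaries_and_scales={100: 0.1, 1000: 0.1}
)
```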