Skip to content

Commit

Permalink
add dtype option to create lammps model
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Oct 2, 2024
1 parent 4bfc7a0 commit 6b7d9c9
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion mace/cli/create_lammps_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ def parse_args():
help="Head of the model to be converted to LAMMPS",
default=None,
)
parser.add_argument(
"--dtype",
type=str,
nargs="?",
help="Data type of the model to be converted to LAMMPS",
default="float64",
)
return parser.parse_args()


Expand Down Expand Up @@ -58,7 +65,11 @@ def main():
model_path,
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
model = model.double().to("cpu")
if args.dtype == "float64":
model = model.double().to("cpu")
elif args.dtype == "float32":
print("Converting model to float32, this may cause loss of precision.")
model = model.float().to("cpu")

if args.head is None:
head = select_head(model)
Expand Down

0 comments on commit 6b7d9c9

Please sign in to comment.