CHGNet-matgl implementation #242
Conversation
# Conflicts: # matgl/apps/pes.py # matgl/utils/training.py
# Conflicts: # matgl/graph/data.py
# Conflicts: # matgl/utils/training.py
…nto chgnet # Conflicts: # tests/graph/test_data.py
# Conflicts: # matgl/graph/data.py
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #242      +/-   ##
==========================================
- Coverage   98.80%   97.67%    -1.13%
==========================================
  Files          33       35        +2
  Lines        2752     3136      +384
==========================================
+ Hits         2719     3063      +344
- Misses         33       73       +40

☔ View full report in Codecov by Sentry.
This looks great!! @kenko911 will follow up with some suggestions and we can merge soon.
# Conflicts: # pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/README.md # pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/model.json # src/matgl/ext/ase.py # src/matgl/layers/__init__.py # src/matgl/layers/_core.py # src/matgl/layers/_graph_convolution.py # src/matgl/layers/_norm.py # src/matgl/models/_chgnet.py # src/matgl/utils/training.py # tests/graph/test_data.py # tests/layers/test_norm.py # tests/models/test_chgnet.py
final version of chgnet with full git history
Walkthrough
This update encompasses enhancements and refinements across various components of the MatGL framework, focusing on model training, graph operations, and API usability improvements. Key changes include the introduction of new normalization and graph convolution layers, modifications to model configurations for handling magnetic moments instead of site-wise properties, and updates to the ASE interface to support new calculations. Additionally, there are updates to documentation, testing, and example notebooks to align with the latest functionalities.

Changes
This table summarizes the changes made across the MatGL framework, highlighting the focus on enhancing model capabilities and refining the computational tools for materials science research.

Recent Review Details
Configuration used: .coderabbit.yaml
Files selected for processing (1)
Files skipped from review as they are similar to previous changes (1)
Actionable comments posted: 23
Review Status
Configuration used: CodeRabbit UI
Files selected for processing (26)
- Mo.cif (1 hunks)
- Mo_hello.cif (1 hunks)
- Mo_relaxed.cif (1 hunks)
- pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/README.md (1 hunks)
- pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/model.json (1 hunks)
- pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/README.md (1 hunks)
- pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/model.json (1 hunks)
- pyproject.toml (1 hunks)
- src/matgl/apps/pes.py (1 hunks)
- src/matgl/ext/ase.py (3 hunks)
- src/matgl/graph/compute.py (2 hunks)
- src/matgl/graph/data.py (11 hunks)
- src/matgl/layers/__init__.py (1 hunks)
- src/matgl/layers/_core.py (4 hunks)
- src/matgl/layers/_graph_convolution.py (14 hunks)
- src/matgl/layers/_norm.py (1 hunks)
- src/matgl/models/__init__.py (1 hunks)
- src/matgl/models/_chgnet.py (1 hunks)
- src/matgl/utils/training.py (8 hunks)
- tests/ext/test_ase.py (1 hunks)
- tests/graph/test_data.py (24 hunks)
- tests/layers/test_core_and_embedding.py (4 hunks)
- tests/layers/test_norm.py (1 hunks)
- tests/models/test_chgnet.py (1 hunks)
- tests/utils/test_io.py (1 hunks)
- tests/utils/test_training.py (16 hunks)
Files skipped from review due to trivial changes (1)
- Mo.cif
Additional Context Used
Additional comments not posted (50)
src/matgl/models/__init__.py (1)
5-5: The import of CHGNet is correctly added, making it accessible as part of the matgl.models module.

Mo_hello.cif (1)
1-28: The CIF file for Mo follows the standard CIF format and correctly describes a simple cubic structure. The crystallographic data appears consistent and well-formatted.

Mo_relaxed.cif (1)
1-28: The CIF file for the relaxed Mo structure follows the standard CIF format and correctly describes a simple cubic structure. It's assumed that "relaxed" refers to some form of computational optimization or experimental measurement, although the file does not detail this process.

src/matgl/layers/__init__.py (1)
9-30: The import statements for new core layer classes and normalization classes are correctly added, making them accessible as part of the matgl.layers module. This ensures that users can easily utilize these new functionalities.

pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/README.md (1)
3-3: Consider revising "This model is a CHGNet universal potential trained from the Materials Project trajectory (MPtrj) dataset" to "This model is a CHGNet universal potential, trained using the Materials Project trajectory (MPtrj) dataset," for improved clarity.

src/matgl/layers/_norm.py (1)
14-62: The implementation of GraphNorm and the extension of LayerNorm are correctly done, adhering to best practices and the methodologies described in the referenced materials. These additions enhance the normalization capabilities of the framework.

tests/layers/test_norm.py (1)
9-62: The test cases for GraphNorm are well-implemented, covering a variety of scenarios including batched and unbatched graphs, and both node and edge features. The use of deterministic setups and appropriate assertions ensures the reliability of the tests.

tests/utils/test_io.py (1)
79-80: The addition of a test case for loading the CHGNet-MPtrj-2024.2.13-PES-11M model is correctly implemented and aligns with the PR's objectives.

pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/model.json (1)
1-66: The model configuration for CHGNet-MPtrj-2023.12.1-PES-2.7M is well-defined and aligns with the CHGNet design principles, including the use of normalization layers and specific cutoff distances.

pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/model.json (1)
1-67: The model configuration for CHGNet-MPtrj-2024.2.13-PES-11M correctly scales up the model dimensions while maintaining the CHGNet design principles.

tests/models/test_chgnet.py (1)
44-57: The addition of the test_CHGNetCalculator function is correctly implemented and tests key functionalities of the CHGNet model, including magnetic moments calculations.

pyproject.toml (1)
18-19: The addition of Bowen Deng and Luis Barroso-Luque to the list of authors is correctly formatted and acknowledges their contributions to the project.

tests/ext/test_ase.py (1)
44-57: The test_CHGNetCalculator function correctly tests the CHGNetCalculator functionality, including magnetic moments calculations, and aligns with the PR's objectives.

src/matgl/apps/pes.py (1)
98-98: The change to ensure device consistency in the matrix operation by explicitly specifying the device is a best practice in PyTorch and improves the robustness of the code.

tests/layers/test_core_and_embedding.py (2)
49-55: The addition of tests for MLP_norm with both "layer" and "graph" normalization is correctly implemented and ensures the functionality works as expected under different configurations.
56-62: The addition of tests for GatedMLP_norm with both "layer" and "graph" normalization is correctly implemented and ensures the functionality works as expected under different configurations.

src/matgl/graph/compute.py (2)
81-82: The default value for the directed parameter in the create_line_graph function is now explicitly documented as False. This improves code readability and makes the function's behavior clearer to users.
233-233: The change from a lambda function to torch.gt for the pruning condition in _create_directed_line_graph simplifies the code and potentially improves performance by using a more direct comparison method.

src/matgl/layers/_core.py (3)
5-5: Adding imports for dgl, GraphNorm, and LayerNorm is necessary for the new classes MLP_norm and GatedMLP_norm that utilize these components. This ensures that the required functionalities are available within the file.
100-171: The MLP_norm class introduces a multi-layer perceptron with optional normalization layers. This class is well-structured and provides flexibility in terms of normalization type and whether to normalize hidden layers. However, it's important to ensure that the g parameter in the forward method is always provided when graph normalization is used, as it's required for the GraphNorm layer.
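To make the point about the g argument concrete, here is a minimal sketch of graph-wise normalization, using NumPy as a stand-in for the actual PyTorch/DGL implementation (the function name and signature are illustrative, not matgl's API): per-graph statistics cannot be computed without the node-to-graph assignment, so the call should fail loudly when that assignment is missing.

```python
import numpy as np

def graph_norm(x, graph_idx=None):
    """Normalize node features per graph (illustrative stand-in for GraphNorm).

    x: (n_nodes, n_feats) feature matrix; graph_idx: (n_nodes,) graph id per node.
    """
    if graph_idx is None:
        # Mirror of the review point: graph normalization is undefined
        # without knowing which graph each node belongs to.
        raise ValueError("graph assignment is required for graph normalization")
    out = np.empty_like(x, dtype=float)
    for gid in np.unique(graph_idx):
        mask = graph_idx == gid
        mu = x[mask].mean(axis=0)          # per-graph feature mean
        sigma = x[mask].std(axis=0) + 1e-8  # per-graph std, epsilon for stability
        out[mask] = (x[mask] - mu) / sigma
    return out
```

After this check, each graph's node features are standardized independently, which is the behavior a batched forward pass needs.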
208-264: The GatedMLP_norm class extends the concept of a gated multi-layer perceptron by incorporating normalization layers. This class effectively combines the functionalities of MLP_norm for both the main and gate networks, allowing for a more sophisticated model architecture. Similar to MLP_norm, ensure that the graph parameter is correctly handled when graph normalization is applied.

tests/graph/test_data.py (4)
4-4: Adding import shutil is necessary for the new cleanup procedure using shutil.rmtree to remove the dataset save path after tests. This ensures proper cleanup and avoids potential side effects on subsequent tests.
24-30: The addition of save_cache=False in MGLDataset initialization within tests is a good practice to prevent caching during test runs. This ensures that each test is run with a fresh dataset, avoiding potential issues with cached data affecting test outcomes.
94-94: Using shutil.rmtree for cleanup after dataset tests is an effective way to ensure that any files or directories created during the test are properly removed. This helps maintain a clean test environment and prevents disk space from being consumed by temporary test data.
246-246: Updating collate_fn to collate_fn_graph in MGLDataLoader calls aligns with the specific requirements of the tests being performed. This ensures that the correct collation function is used for graph data, which is crucial for the proper handling of batched graph inputs.

src/matgl/utils/training.py (2)
6-6: The import of Any and modification of Literal in the typing import are appropriate for the added functionality. Ensure that these types are utilized effectively in the type annotations throughout the file.
469-507: 📝 NOTE: This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [454-504]
The loss_fn method's implementation to accommodate loss_params, allow_missing_labels, and site_wise_target is well-done. It's crucial to ensure that loss_params is properly documented, especially since it directly influences the loss calculation. Additionally, consider adding a brief comment explaining the logic behind handling missing labels and the site_wise_target parameter for future maintainability.
+ # Apply additional parameters to the loss function if provided.
+ # Handle missing labels by filtering out NaN values if `allow_missing_labels` is True.
+ # Adjust the loss calculation based on `site_wise_target` to support different prediction modes.

src/matgl/ext/ase.py (2)
125-125: Extending the implemented_properties to include "magmoms" is a necessary change to support the calculation of magnetic moments. This aligns with the PR's objective to enhance the PESCalculator functionality.
148-148: The addition of self.compute_magmoms based on potential.calc_site_wise is a logical way to determine whether magnetic moments should be calculated. Ensure that calc_site_wise is properly documented in the Potential class to clarify its role in enabling magnetic moments calculation.

Verification successful
The search results indicate that calc_site_wise is indeed used within the context of the Potential class, specifically in the files src/matgl/utils/training.py and src/matgl/apps/pes.py. The usage in src/matgl/apps/pes.py includes a brief documentation comment: "calc_site_wise: Enable site-wise property calculation." This suggests that there is at least some level of documentation regarding calc_site_wise and its role in enabling site-wise property calculation, which aligns with the review comment's concern about ensuring proper documentation for this attribute.
Given this information, it can be concluded that the review comment is consistent with the codebase, as there is evidence of calc_site_wise being documented, at least to some extent, in relation to enabling magnetic moments calculation.

Scripts Executed
The following scripts were executed for the analysis:

Script:
#!/bin/bash
# Verify the documentation of `calc_site_wise` in the Potential class.
rg --type py "class Potential" -A 10 | grep -i "calc_site_wise"

Length of output: 63

Script:
#!/bin/bash
# Find the file or files containing the Potential class
potential_files=$(rg --type py -l "class Potential")

# Search for calc_site_wise within those files to check for documentation
for file in $potential_files; do
  echo "Searching in $file"
  rg --type py "calc_site_wise" $file
done

Length of output: 802
src/matgl/models/_chgnet.py (3)
188-193: Excellent error handling for invalid activation types. This ensures that users are informed about the available options if they input an unsupported activation function. It's a good practice to guide users towards correct usage.
317-400: The forward method is well-structured and covers the necessary steps for processing with CHGNet, including graph and line graph creation, feature embedding, message passing, and readout. However, it's important to ensure that all tensors are moved to the appropriate device (g.device) to avoid device mismatches, especially in operations like tensor creation (torch.tensor) and when using external tensors (state_attr).
401-428: The predict_structure method provides a convenient way to predict properties directly from structures, which enhances the usability of the CHGNet model. However, it's crucial to ensure that state_feats and graph_converter are properly documented in the method's docstring, including their types and roles, to guide users on how to use this method effectively.

src/matgl/graph/data.py (5)
25-25: The addition of include_line_graph and multiple_values_per_target parameters in collate_fn_graph increases the function's flexibility. However, ensure that these new parameters are documented in the function's docstring to inform users about their purpose and usage.
155-162: The MGLDataset class has been enhanced with new parameters like directed_line_graph, structures, labels, and save_cache. Ensure that these parameters are fully documented in the class docstring, including their types, default values, and descriptions, to guide users on how to use this class effectively.
306-368: The CHGNetDataset class introduces a specialized dataset for CHGNet, including handling of structures, labels, and graph conversions. It's important to ensure that all attributes and methods are properly documented, especially the process method, to guide users on how to extend or use this class. Additionally, consider adding error handling for cases where the converter is not provided but required for processing.
470-517: The OOMCHGNetDataset class provides an out-of-memory approach to handling CHGNet datasets. While this approach can be beneficial for large datasets, it's crucial to document the limitations and requirements of this class, such as the need for pre-processed and chunked data files. Additionally, ensure that the load and __getitem__ methods are optimized for performance to handle large datasets efficiently.
519-628: The ChunkedCHGNetDataset class introduces a way to handle large datasets in chunks, which can improve memory usage and performance. However, the process method raises NotImplementedError, indicating that data processing is not supported for this class. It's important to document this limitation and provide guidance on how users should prepare their data before using this class. Additionally, consider implementing or suggesting tools for chunking datasets to facilitate the use of this class.

tests/utils/test_training.py (13)
35-39: Adding the save_cache=False parameter to the MGLDataset instantiation is a good practice for tests to ensure that no state is preserved between test runs, which could lead to flaky tests. This change enhances the reliability of the test suite.
51-51: The update to use collate_fn_graph for data collation is aligned with the specific needs of the test case. It's important to ensure that the chosen collate function matches the data structure and requirements of the model being tested.
103-103: The addition of save_cache=False in various dataset instantiations throughout the file is consistent and ensures test isolation. This is a good practice.
177-177: The consistent use of save_cache=False across all dataset instantiations in the test file is noted and approved for ensuring test reliability.
249-249: The use of save_cache=False is consistently applied across dataset instantiations, which is good for maintaining test isolation.
320-320: The addition of save_cache=False in the dataset instantiation is consistent with best practices for test isolation.
366-366: The use of save_cache=False in dataset instantiation is consistent and ensures that tests do not interfere with each other by caching data.
436-436: The consistent application of save_cache=False across dataset instantiations in this file is good practice for ensuring test isolation.
507-507: The use of save_cache=False in dataset instantiation is consistent with the approach taken throughout the test file to ensure test isolation.
577-577: The addition of save_cache=False in the dataset instantiation is consistent with the approach taken throughout the test file to ensure test isolation.
637-717: The introduction of the test_chgnet_training function is a significant addition, aligning with the PR's objective to integrate CHGNet into the framework. This test covers various aspects of CHGNet's functionality, including dataset preparation with directed line graphs, model instantiation, and training/testing procedures. It's crucial to ensure that the test adequately covers the model's capabilities and edge cases. Consider adding more assertions to verify the model's predictions more thoroughly, such as checking the shape of the output or specific values beyond just being less than 0.
718-770: The test_chgnet_training_without_m function appears to test CHGNet training without magnetic moments. This is a valuable test case, ensuring that the model can handle datasets where certain features might be missing. However, the function name might be misleading as it suggests the absence of magnetic moments (m), but the test setup does not explicitly demonstrate this scenario. Clarifying the test's purpose or adjusting the setup to reflect the absence of magnetic moments explicitly would improve readability and intent.
771-771: The teardown_class method is a good practice for cleaning up after tests, ensuring that artifacts like logs do not persist and potentially affect subsequent tests.
tests/utils/test_training.py
Outdated
from matgl.graph.data import MGLDataLoader, MGLDataset, collate_fn_efs, collate_fn_efsm, collate_fn_graph
from matgl.models import CHGNet, M3GNet, MEGNet, SO3Net, TensorNet
The import statement includes models that are no longer used in the test cases (M3GNet, MEGNet, SO3Net, TensorNet). Since the PR focuses on integrating CHGNet, consider removing unused imports to keep the code clean and maintainable.
- from matgl.models import CHGNet, M3GNet, MEGNet, SO3Net, TensorNet
+ from matgl.models import CHGNet
src/matgl/utils/training.py
Outdated
data_mean: average of training data
data_std: standard deviation of training data
loss: loss function used for training
loss_params: parameters for loss function
optimizer: optimizer for training
scheduler: scheduler for training
lr: learning rate for training
decay_steps: number of steps for decaying learning rate
decay_alpha: parameter determines the minimum learning rate.
sync_dist: whether sync logging across all GPU workers or not
allow_missing_labels: Whether to allow missing labels or not.
    These should be present in the dataset as torch.nans and will be skipped in computing the loss.
site_wise_target: Whether to predict the absolute site-wise value of magmoms or adapt the loss function
    to predict the signed value breaking symmetry. If None given the loss function will be adapted.
**kwargs: Passthrough to parent init.
"""
assert energy_weight >= 0, f"energy_weight has to be >=0. Got {energy_weight}!"
📝 NOTE
This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [271-305]
Adding loss_params, allow_missing_labels, and site_wise_target parameters to the constructor of PotentialLightningModule significantly enhances the module's flexibility. However, it's crucial to validate these parameters, especially loss_params and site_wise_target, to ensure they are within expected ranges or values.
+ if loss_params is not None and not isinstance(loss_params, dict):
+ raise ValueError("`loss_params` must be a dictionary.")
+ if site_wise_target not in [None, "absolute", "symbreak"]:
+ raise ValueError("`site_wise_target` must be one of [None, 'absolute', 'symbreak'].")
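The allow_missing_labels behavior described in the docstring (missing labels stored as NaN and skipped when computing the loss) can be sketched as follows. This is an illustrative NumPy stand-in, not the actual PyTorch loss code in matgl.utils.training; masked_l1_loss is a hypothetical name.

```python
import numpy as np

def masked_l1_loss(pred, target):
    """Mean absolute error over labelled entries only.

    Missing labels are stored as NaN in the dataset; exclude them from the
    loss instead of letting a single NaN poison the whole batch.
    """
    mask = ~np.isnan(target)
    if not mask.any():
        return 0.0  # no labelled entries in this batch
    return float(np.abs(pred[mask] - target[mask]).mean())
```

The same masking idea carries over to torch tensors via torch.isnan, where the mask must be applied before any reduction so gradients only flow through labelled entries.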
def on_load_checkpoint(self, checkpoint: dict[str, Any]):
    """# noqa: D200
    hacky hacky hack to add missing keys to the state dict when changes are made.
    """
    for key in self.state_dict():
        if key not in checkpoint["state_dict"]:
            checkpoint["state_dict"][key] = self.state_dict()[key]
The on_load_checkpoint method is a clever workaround for handling missing keys in the state dict due to changes in the model's architecture or parameters. However, it's important to log a warning when adding missing keys to inform the user of these adjustments.
+ if key not in checkpoint["state_dict"]:
+ logging.warning(f"Adding missing key to checkpoint state_dict: {key}")
@@ -72,7 +78,7 @@ def edge_update_(self, graph: dgl.DGLGraph) -> Tensor:
     """Perform edge update.

     Args:
-        graph: Input graph
+        graph: Input g
In the edge_update_ method, the argument description for graph is incomplete. It should clearly describe the expected input.
- graph: Input g
+ graph: Input DGLGraph object representing the graph structure.
@@ -85,7 +91,7 @@
     """Perform node update.

     Args:
-        graph: Input graph
+        graph: Input g
Similarly, in the node_update_ method, the argument description for graph is incomplete and should be clarified.
- graph: Input g
+ graph: Input DGLGraph object representing the graph structure.
src/matgl/graph/data.py
Outdated
if collate_fn is None:
    if "forces" not in train_data.dataset.labels:
        collate_fn = collate_fn_graph
    else:
        if "stresses" not in train_data.dataset.labels:
            collate_fn = partial(collate_fn_efs, include_stress=False)
        else:
            if "magmoms" not in train_data.dataset.labels:  # noqa: SIM108
                collate_fn = collate_fn_efs
            else:
                collate_fn = collate_fn_efsm
The dynamic assignment of collate_fn in MGLDataLoader based on the presence of labels in the dataset is a smart approach to handle different types of data. However, this logic could become complex and hard to maintain as more conditions are added. Consider refactoring this into a separate function or using a more scalable approach to handle the assignment of collate_fn.
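One possible shape for such a refactor is a small selector function with early returns. This is a sketch only: select_collate_fn is a hypothetical name, and the collate functions are passed in as arguments to keep the example self-contained.

```python
from functools import partial

def select_collate_fn(labels, collate_graph, collate_efs, collate_efsm):
    # Flatten the nested if/else into early returns, checked in order of
    # increasing label requirements: graph-only, +forces, +stresses, +magmoms.
    if "forces" not in labels:
        return collate_graph
    if "stresses" not in labels:
        return partial(collate_efs, include_stress=False)
    if "magmoms" not in labels:
        return collate_efs
    return collate_efsm
```

Each new label type then costs one extra early return rather than another level of nesting.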
This model is a CHGNet universal potential trained from the Materials Project trajectory (MPtrj) dataset
that contains over 1.5 million structures with 89 elements.
This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
Replace "pytorch" with "PyTorch" to adhere to the correct capitalization of the framework's name.
- original pytorch implementation
+ original PyTorch implementation
# Training dataset

MPtrj-2022.9: Materials Project trajectory dataset that contains GGA and GGA+U static and relaxation calculations.
- Train-Val-Test splitting with mp-id: 0.95 - 0.5
For numerical ranges, consider using an en dash (–) instead of a hyphen (-) for improved readability. For example, "0.95–0.5" instead of "0.95 - 0.5".
- 0.95 - 0.5
+ 0.95–0.5
This model is a CHGNet universal potential trained from the Materials Project trajectory (MPtrj) dataset
that contains over 1.5 million structures with 89 elements.
This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
Replace "pytorch" with "PyTorch" to adhere to the correct capitalization of the framework's name.
- original pytorch implementation
+ original PyTorch implementation
# Training dataset

MPtrj-2022.9: Materials Project trajectory dataset that contains GGA and GGA+U static and relaxation calculations.
- Train-Val-Test splitting with mp-id: 0.9 - 0.5 - 0.5
For numerical ranges, consider using an en dash (–) instead of a hyphen (-) for improved readability. For example, "0.9–0.5–0.5" instead of "0.9 - 0.5 - 0.5".
- 0.9 - 0.5 - 0.5
+ 0.9–0.5–0.5
Actionable comments posted: 0
Review Status
Configuration used: CodeRabbit UI
Files selected for processing (1)
- src/matgl/graph/data.py (10 hunks)
Additional Context Used
Additional comments not posted (2)
src/matgl/graph/data.py (2)
69-84: The implementation of collate_fn_efsm assumes that all items in the batch have the same keys, which might not always be the case. Consider iterating over all items in labels to check for the presence of "stresses" more robustly.
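A more robust check along the suggested lines could look like this; batch_has_stresses is a hypothetical helper, and labels here stands for the per-item label dicts of the batch.

```python
def batch_has_stresses(labels):
    # Check every item in the batch rather than only the first one,
    # so a batch with heterogeneous label keys is detected correctly.
    return all("stresses" in item for item in labels)
```

The collate function can then branch on this once per batch instead of trusting the first item to be representative.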
158-179: 📝 NOTE: This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [141-176]
The documentation for MGLDataset is thorough and informative. However, there's a minor inconsistency in the documentation style. For parameters like directed_line_graph, the description starts on the same line as the parameter name, while for others like raw_dir, it starts on a new line. Consider standardizing the documentation style for consistency.
@kenko911 Is this ready to be merged? You were supposed to review and provide comments.

Hi @shyuep, I already discussed this with Bowen and I am currently working on it. I will make a PR by the end of this week.
Actionable comments posted: 5

Out of diff range and nitpick comments (9)

src/matgl/graph/data.py (6)

25-27: Consider adding type hints for the return values of `collate_fn_graph` to improve code readability and maintainability.

Line range hint 46-66: The function `collate_fn_pes` handles multiple conditions and returns different types of data based on the flags. This could lead to maintenance challenges. Consider refactoring to separate functions or a class to handle these variations more cleanly.
Line range hint 81-109: The dynamic assignment of `collate_fn` based on dataset labels in `MGLDataLoader` is complex. Consider simplifying this logic or encapsulating it in a method to improve readability and maintainability.
Line range hint 132-167: The constructor of `MGLDataset` has a large number of parameters, which can be error-prone and hard to maintain. Consider using a data class or a configuration object to encapsulate these parameters.
Line range hint 202-221: The `process` method in `MGLDataset` is quite long and handles multiple responsibilities. Consider breaking it down into smaller, more focused methods to improve readability and maintainability.

240-254: In the `save` method, the condition `if self.save_cache is False` could be more clearly written as `if not self.save_cache`.

src/matgl/utils/training.py (1)
Line range hint 267-517: The `PotentialLightningModule` class is well-structured but complex due to handling multiple types of data (energy, force, stress, magmom). Consider refactoring to separate concerns, possibly by creating specialized classes for each data type.

tests/graph/test_data.py (2)
75-75: Consider using more descriptive names for dataset instances to improve code readability.

130-130: Consider using more descriptive names for dataset instances to improve code readability.
        ]
    }
],
"outputs": [],
Consider adding error handling for the data fetching operation to manage potential issues with the URL or data format changes.
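One pattern for that error handling is to wrap the fetch so URL or format failures are reported with context instead of surfacing as bare exceptions deep in notebook execution. The helper below is a generic sketch; the function and parameter names are hypothetical.

```python
def fetch_with_handling(fetch, url):
    """Call `fetch(url)` and convert low-level I/O or parse errors into a
    single, clearly attributed RuntimeError."""
    try:
        return fetch(url)
    except (OSError, ValueError) as exc:
        raise RuntimeError(f"failed to fetch dataset from {url}: {exc}") from exc
```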
dataset = MGLDataset(
    structures=structures,
    converter=cry_graph,
    labels={"label": label},
    clear_processed=True,
    save_cache=False,
)
Consider refactoring to reduce duplication of dataset initialization parameters across tests.
g1, lat1, state1, label1 = dataset[0]
assert label1["label"] == label[0]
assert g1.num_edges() == cry_graph.get_graph(LiFePO4)[0].num_edges()
assert g1.num_nodes() == cry_graph.get_graph(LiFePO4)[0].num_nodes()
assert np.allclose(lat1.detach().numpy(), structures[0].lattice.matrix)
self.teardown_class()
shutil.rmtree(f"{dataset.save_path}")
Add error handling for directory removal to prevent potential data loss.
@@ -148,6 +163,7 @@
     assert np.shape(pes2["forces"])[0], 10
     assert np.allclose(lat1.detach().numpy(), structures[0].lattice.matrix)
     assert np.allclose(lat2.detach().numpy(), structures[1].lattice.matrix)
+    shutil.rmtree(f"{dataset.save_path}")
Add error handling for directory removal to prevent potential data loss.
@@ -184,8 +199,7 @@
     assert g1.num_nodes() == cry_graph.get_graph(LiFePO4)[0].num_nodes()
     assert g2.num_edges() == cry_graph.get_graph(BaNiO3)[0].num_edges()
     assert g2.num_nodes() == cry_graph.get_graph(BaNiO3)[0].num_nodes()
-    self.teardown_class()
     shutil.rmtree(f"{dataset.save_path}")
Add error handling for directory removal to prevent potential data loss.
Hi @shyuep, I have restructured some MatGL modules and added more unit tests for the CHGNet implementation. I discussed this with Bowen and we are both happy with the current changes. Only 3 unit tests fail, due to the addition and refactoring of some variables in the Potential class. When we run the unit test for loading pretrained PES models in MatGL, it fails when loading the previous model version from e.g. /home/runner/.cache/matgl/M3GNet-MP-2021.2.8-DIRECT-PES/model.json. I also updated the model version of the Potential class from 2 to 3, and I believe this will be automatically fixed once we merge the changes into the main branch. Please have a look and let me know if there are any problems. Thanks!!
Summary
CHGNet implementation:
with two pretrained weights released
Implemented Functions
Todos
Summary by CodeRabbit