
CHGNet-matgl implementation #242

Merged
merged 351 commits into from
May 6, 2024

Conversation


@bowen-bd bowen-bd commented Mar 20, 2024

Summary

CHGNet implementation, with two pretrained weights released.

Implemented Functions

Todos

  • examples and README

Summary by CodeRabbit

  • New Features
    • Introduced new crystal structures for Molybdenum with detailed atomic positions and unit cell parameters.
    • Launched two versions of the CHGNet universal potential model for advanced materials science simulations.
    • Extended the functionality of the PESCalculator class to include magnetic moment calculations.
    • Introduced new dataset classes and functionalities to improve data handling and model training.
    • Added new core layer and normalization classes to enhance model architecture capabilities.
    • Expanded graph convolution capabilities with new classes and methods.
  • Enhancements
    • Updated the MatGL framework's author list.
    • Improved device consistency in matrix operations for potential energy surface calculations.
    • Clarified default parameter values and improved function implementations in graph computation.
  • Documentation
    • Updated READMEs for CHGNet models with configuration details and data preprocessing specifics.
  • Bug Fixes
    • Addressed issues in test cases, ensuring proper functionality of new features and models.
  • Chores
    • Various updates to improve codebase clarity and maintainability.

lbluque added 30 commits August 12, 2023 17:23
# Conflicts:
#	matgl/apps/pes.py
#	matgl/utils/training.py
# Conflicts:
#	matgl/graph/data.py
# Conflicts:
#	matgl/utils/training.py
…nto chgnet

# Conflicts:
#	tests/graph/test_data.py
# Conflicts:
#	matgl/graph/data.py

codecov bot commented Mar 20, 2024

Codecov Report

Attention: Patch coverage is 87.90698%, with 52 lines in your changes missing coverage. Please review.

Project coverage is 97.67%. Comparing base (8ed58f9) to head (4a875e0).

❗ Current head 4a875e0 differs from the pull request's most recent head 49c5193. Consider uploading reports for commit 49c5193 to get more accurate results.

Files Patch % Lines
src/matgl/layers/_graph_convolution.py 88.43% 20 Missing ⚠️
src/matgl/utils/training.py 67.64% 11 Missing ⚠️
src/matgl/models/_chgnet.py 90.65% 10 Missing ⚠️
src/matgl/graph/data.py 84.09% 7 Missing ⚠️
src/matgl/layers/_core.py 89.74% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #242      +/-   ##
==========================================
- Coverage   98.80%   97.67%   -1.13%     
==========================================
  Files          33       35       +2     
  Lines        2752     3136     +384     
==========================================
+ Hits         2719     3063     +344     
- Misses         33       73      +40     



shyuep commented Mar 22, 2024

This looks great!! @kenko911 will follow up with some suggestions and we can merge soon.

lbluque and others added 2 commits March 28, 2024 19:46
# Conflicts:
#	pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/README.md
#	pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/model.json
#	src/matgl/ext/ase.py
#	src/matgl/layers/__init__.py
#	src/matgl/layers/_core.py
#	src/matgl/layers/_graph_convolution.py
#	src/matgl/layers/_norm.py
#	src/matgl/models/_chgnet.py
#	src/matgl/utils/training.py
#	tests/graph/test_data.py
#	tests/layers/test_norm.py
#	tests/models/test_chgnet.py
final version of chgnet with full git history

coderabbitai bot commented Mar 29, 2024

Walkthrough

This update encompasses enhancements and refinements across various components of the MatGL framework, focusing on model training, graph operations, and API usability improvements. Key changes include the introduction of new normalization and graph convolution layers, modifications to model configurations for handling magnetic moments instead of site-wise properties, and updates to the ASE interface to support new calculations. Additionally, there are updates to documentation, testing, and example notebooks to align with the latest functionalities.

Changes

Files Change Summary
Mo.cif, Mo_hello.cif, Mo_relaxed.cif Defined crystal structures for Molybdenum with specific atomic positions and unit cell parameters.
.../CHGNet-MPtrj-2023.12.1-PES-2.7M/README.md, .../model.json Introduced a CHGNet model with specifics on the PyTorch implementation and data preprocessing for PES calculations.
.../CHGNet-MPtrj-2024.2.13-PES-11M/README.md, .../model.json Updated CHGNet model trained with enhanced architecture settings for advanced PES calculations in material science.
.../M3GNet-MP-2021.2.8-DIRECT-PES/model.json Shifted focus from site-wise property calculations to magnetic moments in the M3GNet model configuration.
pyproject.toml Added Bowen Deng and Luis Barroso-Luque to the list of authors.
src/matgl/apps/pes.py, src/matgl/cli.py, src/matgl/ext/ase.py Major updates to parameter naming, function logic, and extended support for magnetic moment calculations.
src/matgl/graph/..., src/matgl/layers/..., src/matgl/models/... Added new classes and methods enhancing graph convolution capabilities and normalization processes.
examples/... Updated API keys, Python versions, and import statements in example notebooks to reflect the latest library configurations.
tests/... Enhanced testing for new functionalities, including magnetic moment calculations and updated model configurations.

This table summarizes the changes made across the MatGL framework, highlighting the focus on enhancing model capabilities and refining the computational tools for materials science research.


Recent Review Details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits Files that changed from the base of the PR and between cc80907 and 49c5193.
Files selected for processing (1)
  • src/matgl/apps/pes.py (7 hunks)
Files skipped from review as they are similar to previous changes (1)
  • src/matgl/apps/pes.py



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 23

Review Status

Configuration used: CodeRabbit UI

Commits Files that changed from the base of the PR and between ee9e987 and 392cc03.
Files selected for processing (26)
  • Mo.cif (1 hunks)
  • Mo_hello.cif (1 hunks)
  • Mo_relaxed.cif (1 hunks)
  • pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/README.md (1 hunks)
  • pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/model.json (1 hunks)
  • pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/README.md (1 hunks)
  • pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/model.json (1 hunks)
  • pyproject.toml (1 hunks)
  • src/matgl/apps/pes.py (1 hunks)
  • src/matgl/ext/ase.py (3 hunks)
  • src/matgl/graph/compute.py (2 hunks)
  • src/matgl/graph/data.py (11 hunks)
  • src/matgl/layers/init.py (1 hunks)
  • src/matgl/layers/_core.py (4 hunks)
  • src/matgl/layers/_graph_convolution.py (14 hunks)
  • src/matgl/layers/_norm.py (1 hunks)
  • src/matgl/models/init.py (1 hunks)
  • src/matgl/models/_chgnet.py (1 hunks)
  • src/matgl/utils/training.py (8 hunks)
  • tests/ext/test_ase.py (1 hunks)
  • tests/graph/test_data.py (24 hunks)
  • tests/layers/test_core_and_embedding.py (4 hunks)
  • tests/layers/test_norm.py (1 hunks)
  • tests/models/test_chgnet.py (1 hunks)
  • tests/utils/test_io.py (1 hunks)
  • tests/utils/test_training.py (16 hunks)
Files skipped from review due to trivial changes (1)
  • Mo.cif
Additional Context Used
Additional comments not posted (50)
src/matgl/models/__init__.py (1)

5-5: The import of CHGNet is correctly added, making it accessible as part of the matgl.models module.

Mo_hello.cif (1)

1-28: The CIF file for Mo follows the standard CIF format and correctly describes a simple cubic structure. The crystallographic data appears consistent and well-formatted.

Mo_relaxed.cif (1)

1-28: The CIF file for the relaxed Mo structure follows the standard CIF format and correctly describes a simple cubic structure. It's assumed that "relaxed" refers to some form of computational optimization or experimental measurement, although the file does not detail this process.

src/matgl/layers/__init__.py (1)

9-30: The import statements for new core layer classes and normalization classes are correctly added, making them accessible as part of the matgl.layers module. This ensures that users can easily utilize these new functionalities.

pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/README.md (1)

3-3: Consider revising "This model is a CHGNet universal potential trained from the Materials Project trajectory (MPtrj) dataset" to "This model is a CHGNet universal potential, trained using the Materials Project trajectory (MPtrj) dataset," for improved clarity.

src/matgl/layers/_norm.py (1)

14-62: The implementation of GraphNorm and the extension of LayerNorm are correctly done, adhering to best practices and the methodologies described in the referenced materials. These additions enhance the normalization capabilities of the framework.
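To make the per-graph statistics concrete, here is a minimal scalar sketch of the GraphNorm computation (a hypothetical illustration following the referenced formulation, not the matgl implementation, which operates on DGL graphs and feature tensors):

```python
import math

def graph_norm(node_feats, graph_sizes, alpha=1.0, gamma=1.0, beta=0.0, eps=1e-5):
    """Minimal scalar GraphNorm for a batched graph.

    node_feats: flat list of per-node scalar features for all graphs in a batch.
    graph_sizes: number of nodes per graph; statistics are computed per graph,
    unlike LayerNorm, which normalizes each node's features independently.
    alpha, gamma, beta are the learnable parameters (scalars here for brevity).
    """
    out, start = [], 0
    for n in graph_sizes:
        xs = node_feats[start:start + n]
        mean = sum(xs) / n
        # GraphNorm subtracts only alpha * mean, letting the model learn how
        # much of the graph-level mean to remove before scaling.
        shifted = [x - alpha * mean for x in xs]
        var = sum(s * s for s in shifted) / n
        std = math.sqrt(var + eps)
        out.extend(gamma * s / std + beta for s in shifted)
        start += n
    return out
```

Because each graph in the batch is centred with its own mean, batching graphs of very different feature scales does not leak statistics between them.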

tests/layers/test_norm.py (1)

9-62: The test cases for GraphNorm are well-implemented, covering a variety of scenarios including batched and unbatched graphs, and both node and edge features. The use of deterministic setups and appropriate assertions ensures the reliability of the tests.

tests/utils/test_io.py (1)

79-80: The addition of a test case for loading the CHGNet-MPtrj-2024.2.13-PES-11M model is correctly implemented and aligns with the PR's objectives.

pretrained_models/CHGNet-MPtrj-2023.12.1-PES-2.7M/model.json (1)

1-66: The model configuration for CHGNet-MPtrj-2023.12.1-PES-2.7M is well-defined and aligns with the CHGNet design principles, including the use of normalization layers and specific cutoff distances.

pretrained_models/CHGNet-MPtrj-2024.2.13-PES-11M/model.json (1)

1-67: The model configuration for CHGNet-MPtrj-2024.2.13-PES-11M correctly scales up the model dimensions while maintaining the CHGNet design principles.

tests/models/test_chgnet.py (1)

44-57: The addition of the test_CHGNetCalculator function is correctly implemented and tests key functionalities of the CHGNet model, including magnetic moments calculations.

pyproject.toml (1)

18-19: The addition of Bowen Deng and Luis Barroso-Luque to the list of authors is correctly formatted and acknowledges their contributions to the project.

tests/ext/test_ase.py (1)

44-57: The test_CHGNetCalculator function correctly tests the CHGNetCalculator functionality, including magnetic moments calculations, and aligns with the PR's objectives.

src/matgl/apps/pes.py (1)

98-98: The change to ensure device consistency in the matrix operation by explicitly specifying the device is a best practice in PyTorch and improves the robustness of the code.

tests/layers/test_core_and_embedding.py (2)

49-55: The addition of tests for MLP_norm with both "layer" and "graph" normalization is correctly implemented and ensures the functionality works as expected under different configurations.


56-62: The addition of tests for GatedMLP_norm with both "layer" and "graph" normalization is correctly implemented and ensures the functionality works as expected under different configurations.

src/matgl/graph/compute.py (2)

81-82: The default value for the directed parameter in the create_line_graph function is now explicitly documented as False. This improves code readability and makes the function's behavior clearer to users.


233-233: The change from a lambda function to torch.gt for the pruning condition in _create_directed_line_graph simplifies the code and potentially improves performance by using a more direct comparison method.

src/matgl/layers/_core.py (3)

5-5: Adding imports for dgl, GraphNorm, and LayerNorm is necessary for the new classes MLP_norm and GatedMLP_norm that utilize these components. This ensures that the required functionalities are available within the file.


100-171: The MLP_norm class introduces a multi-layer perceptron with optional normalization layers. This class is well-structured and provides flexibility in terms of normalization type and whether to normalize hidden layers. However, it's important to ensure that the g parameter in the forward method is always provided when graph normalization is used, as it's required for the GraphNorm layer.
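The contract flagged here can be enforced with a fail-fast guard. The sketch below is hypothetical (MLPNormSketch is not a matgl class) and elides the actual linear layers; it only shows the argument check:

```python
class MLPNormSketch:
    """Illustrative guard for an MLP with optional normalization: when graph
    normalization is selected, the forward pass needs the batched graph `g`
    to compute per-graph statistics, so it should fail fast if `g` is None."""

    def __init__(self, normalization=None):
        # normalization: None, "layer", or "graph"
        self.normalization = normalization

    def forward(self, inputs, g=None):
        if self.normalization == "graph" and g is None:
            raise ValueError("normalization='graph' requires the graph `g`")
        # ... linear layers and the selected norm would be applied here ...
        return inputs
```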


208-264: The GatedMLP_norm class extends the concept of a gated multi-layer perceptron by incorporating normalization layers. This class effectively combines the functionalities of MLP_norm for both the main and gate networks, allowing for a more sophisticated model architecture. Similar to MLP_norm, ensure that the graph parameter is correctly handled when graph normalization is applied.

tests/graph/test_data.py (4)

4-4: Adding import shutil is necessary for the new cleanup procedure using shutil.rmtree to remove the dataset save path after tests. This ensures proper cleanup and avoids potential side effects on subsequent tests.


24-30: The addition of save_cache=False in MGLDataset initialization within tests is a good practice to prevent caching during test runs. This ensures that each test is run with a fresh dataset, avoiding potential issues with cached data affecting test outcomes.


94-94: Using shutil.rmtree for cleanup after dataset tests is an effective way to ensure that any files or directories created during the test are properly removed. This helps maintain a clean test environment and prevents disk space from being consumed by temporary test data.


246-246: Updating collate_fn to collate_fn_graph in MGLDataLoader calls aligns with the specific requirements of the tests being performed. This ensures that the correct collation function is used for graph data, which is crucial for the proper handling of batched graph inputs.

src/matgl/utils/training.py (2)

6-6: The import of Any and modification of Literal in the typing import are appropriate for the added functionality. Ensure that these types are utilized effectively in the type annotations throughout the file.


469-507: > 📝 NOTE

This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [454-504]

The loss_fn method's implementation to accommodate loss_params, allow_missing_labels, and site_wise_target is well-done. It's crucial to ensure that loss_params is properly documented, especially since it directly influences the loss calculation. Additionally, consider adding a brief comment explaining the logic behind handling missing labels and the site_wise_target parameter for future maintainability.

+        # Apply additional parameters to the loss function if provided.
+        # Handle missing labels by filtering out NaN values if `allow_missing_labels` is True.
+        # Adjust the loss calculation based on `site_wise_target` to support different prediction modes.
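The missing-label handling described above can be illustrated with a toy scalar loss; `masked_mae` is a hypothetical stand-in for the real tensor-based `loss_fn`:

```python
import math

def masked_mae(preds, targets):
    """Toy scalar version of the allow_missing_labels behaviour: target values
    stored as NaN mark missing labels and are filtered out, instead of turning
    the whole loss into NaN."""
    pairs = [(p, t) for p, t in zip(preds, targets) if not math.isnan(t)]
    if not pairs:
        return 0.0  # every label missing: contribute nothing to the loss
    return sum(abs(p - t) for p, t in pairs) / len(pairs)
```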
src/matgl/ext/ase.py (2)

125-125: Extending the implemented_properties to include "magmoms" is a necessary change to support the calculation of magnetic moments. This aligns with the PR's objective to enhance the PESCalculator functionality.


148-148: The addition of self.compute_magmoms based on potential.calc_site_wise is a logical way to determine whether magnetic moments should be calculated. Ensure that calc_site_wise is properly documented in the Potential class to clarify its role in enabling magnetic moments calculation.

Verification successful

The search results indicate that calc_site_wise is indeed used within the context of the Potential class, specifically in the files src/matgl/utils/training.py and src/matgl/apps/pes.py. The usage in src/matgl/apps/pes.py includes a brief documentation comment: "calc_site_wise: Enable site-wise property calculation." This suggests that there is at least some level of documentation regarding calc_site_wise and its role in enabling site-wise property calculation, which aligns with the review comment's concern about ensuring proper documentation for this attribute.

Given this information, it can be concluded that the review comment is consistent with the codebase, as there is evidence of calc_site_wise being documented, at least to some extent, in relation to enabling magnetic moments calculation.

Scripts Executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Verify the documentation of `calc_site_wise` in the Potential class.
rg --type py "class Potential" -A 10 | grep -i "calc_site_wise"

Length of output: 63



Script:

#!/bin/bash
# Find the file or files containing the Potential class
potential_files=$(rg --type py -l "class Potential")

# Search for calc_site_wise within those files to check for documentation
for file in $potential_files; do
    echo "Searching in $file"
    rg --type py "calc_site_wise" $file
done

Length of output: 802

src/matgl/models/_chgnet.py (3)

188-193: Excellent error handling for invalid activation types. This ensures that users are informed about the available options if they input an unsupported activation function. It's a good practice to guide users towards correct usage.


317-400: The forward method is well-structured and covers the necessary steps for processing with CHGNet, including graph and line graph creation, feature embedding, message passing, and readout. However, it's important to ensure that all tensors are moved to the appropriate device (g.device) to avoid device mismatches, especially in operations like tensor creation (torch.tensor) and when using external tensors (state_attr).


401-428: The predict_structure method provides a convenient way to predict properties directly from structures, which enhances the usability of the CHGNet model. However, it's crucial to ensure that state_feats and graph_converter are properly documented in the method's docstring, including their types and roles, to guide users on how to use this method effectively.

src/matgl/graph/data.py (5)

25-25: The addition of include_line_graph and multiple_values_per_target parameters in collate_fn_graph increases the function's flexibility. However, ensure that these new parameters are documented in the function's docstring to inform users about their purpose and usage.


155-162: The MGLDataset class has been enhanced with new parameters like directed_line_graph, structures, labels, and save_cache. Ensure that these parameters are fully documented in the class docstring, including their types, default values, and descriptions, to guide users on how to use this class effectively.


306-368: The CHGNetDataset class introduces a specialized dataset for CHGNet, including handling of structures, labels, and graph conversions. It's important to ensure that all attributes and methods are properly documented, especially the process method, to guide users on how to extend or use this class. Additionally, consider adding error handling for cases where the converter is not provided but required for processing.


470-517: The OOMCHGNetDataset class provides an out-of-memory approach to handling CHGNet datasets. While this approach can be beneficial for large datasets, it's crucial to document the limitations and requirements of this class, such as the need for pre-processed and chunked data files. Additionally, ensure that the load and __getitem__ methods are optimized for performance to handle large datasets efficiently.


519-628: The ChunkedCHGNetDataset class introduces a way to handle large datasets in chunks, which can improve memory usage and performance. However, the process method raises NotImplementedError, indicating that data processing is not supported for this class. It's important to document this limitation and provide guidance on how users should prepare their data before using this class. Additionally, consider implementing or suggesting tools for chunking datasets to facilitate the use of this class.
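As one possible pre-chunking tool of the kind suggested, here is a minimal sketch; `chunk_records` is hypothetical and not part of matgl, and the real dataset class may expect a different on-disk format:

```python
import json
import math
import pathlib

def chunk_records(records, chunk_size, out_dir):
    """Split pre-processed records into fixed-size JSON chunk files that a
    chunked dataset class could then load lazily, one chunk at a time."""
    out_dir = pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    n_chunks = math.ceil(len(records) / chunk_size)
    paths = []
    for i in range(n_chunks):
        chunk = records[i * chunk_size:(i + 1) * chunk_size]
        # Zero-padded names keep chunks ordered when the directory is listed.
        path = out_dir / f"chunk_{i:05d}.json"
        path.write_text(json.dumps(chunk))
        paths.append(path)
    return paths
```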

tests/utils/test_training.py (13)

35-39: Adding the save_cache=False parameter to the MGLDataset instantiation is a good practice for tests to ensure that no state is preserved between test runs, which could lead to flaky tests. This change enhances the reliability of the test suite.


51-51: The update to use collate_fn_graph for data collation is aligned with the specific needs of the test case. It's important to ensure that the chosen collate function matches the data structure and requirements of the model being tested.


103-103: The addition of save_cache=False in various dataset instantiations throughout the file is consistent and ensures test isolation. This is a good practice.


177-177: The consistent use of save_cache=False across all dataset instantiations in the test file is noted and approved for ensuring test reliability.


249-249: The use of save_cache=False is consistently applied across dataset instantiations, which is good for maintaining test isolation.


320-320: The addition of save_cache=False in the dataset instantiation is consistent with best practices for test isolation.


366-366: The use of save_cache=False in dataset instantiation is consistent and ensures that tests do not interfere with each other by caching data.


436-436: The consistent application of save_cache=False across dataset instantiations in this file is good practice for ensuring test isolation.


507-507: The use of save_cache=False in dataset instantiation is consistent with the approach taken throughout the test file to ensure test isolation.


577-577: The addition of save_cache=False in the dataset instantiation is consistent with the approach taken throughout the test file to ensure test isolation.


637-717: The introduction of the test_chgnet_training function is a significant addition, aligning with the PR's objective to integrate CHGNet into the framework. This test covers various aspects of CHGNet's functionality, including dataset preparation with directed line graphs, model instantiation, and training/testing procedures. It's crucial to ensure that the test adequately covers the model's capabilities and edge cases. Consider adding more assertions to verify the model's predictions more thoroughly, such as checking the shape of the output or specific values beyond just being less than 0.


718-770: The test_chgnet_training_without_m function appears to test CHGNet training without magnetic moments. This is a valuable test case, ensuring that the model can handle datasets where certain features might be missing. However, the function name might be misleading as it suggests the absence of magnetic moments (m), but the test setup does not explicitly demonstrate this scenario. Clarifying the test's purpose or adjusting the setup to reflect the absence of magnetic moments explicitly would improve readability and intent.


771-771: The teardown_class method is a good practice for cleaning up after tests, ensuring that artifacts like logs do not persist and potentially affect subsequent tests.

Comment on lines 15 to 16
from matgl.graph.data import MGLDataLoader, MGLDataset, collate_fn_efs, collate_fn_efsm, collate_fn_graph
from matgl.models import CHGNet, M3GNet, MEGNet, SO3Net, TensorNet

The import statement includes models that are no longer used in the test cases (M3GNet, MEGNet, SO3Net, TensorNet). Since the PR focuses on integrating CHGNet, consider removing unused imports to keep the code clean and maintainable.

- from matgl.models import CHGNet, M3GNet, MEGNet, SO3Net, TensorNet
+ from matgl.models import CHGNet


Comment on lines 293 to 309
data_mean: average of training data
data_std: standard deviation of training data
loss: loss function used for training
loss_params: parameters for loss function
optimizer: optimizer for training
scheduler: scheduler for training
lr: learning rate for training
decay_steps: number of steps for decaying learning rate
decay_alpha: parameter determines the minimum learning rate.
sync_dist: whether sync logging across all GPU workers or not
allow_missing_labels: Whether to allow missing labels or not.
These should be present in the dataset as torch.nans and will be skipped in computing the loss.
site_wise_target: Whether to predict the absolute site-wise value of magmoms or adapt the loss function
to predict the signed value breaking symmetry. If None given the loss function will be adapted.
**kwargs: Passthrough to parent init.
"""
assert energy_weight >= 0, f"energy_weight has to be >=0. Got {energy_weight}!"

📝 NOTE
This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [271-305]

Adding loss_params, allow_missing_labels, and site_wise_target parameters to the constructor of PotentialLightningModule significantly enhances the module's flexibility. However, it's crucial to validate these parameters, especially loss_params and site_wise_target, to ensure they are within expected ranges or values.

+        if loss_params is not None and not isinstance(loss_params, dict):
+            raise ValueError("`loss_params` must be a dictionary.")
+        if site_wise_target not in [None, "absolute", "symbreak"]:
+            raise ValueError("`site_wise_target` must be one of [None, 'absolute', 'symbreak'].")

Comment on lines +352 to +358
def on_load_checkpoint(self, checkpoint: dict[str, Any]):
"""# noqa: D200
hacky hacky hack to add missing keys to the state dict when changes are made.
"""
for key in self.state_dict():
if key not in checkpoint["state_dict"]:
checkpoint["state_dict"][key] = self.state_dict()[key]

The on_load_checkpoint method is a clever workaround for handling missing keys in the state dict due to changes in the model's architecture or parameters. However, it's important to log a warning when adding missing keys to inform the user of these adjustments.

+            if key not in checkpoint["state_dict"]:
+                logging.warning(f"Adding missing key to checkpoint state_dict: {key}")

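The same workaround, extracted into a standalone helper with the suggested warning; `patch_state_dict` is a hypothetical name, and in Lightning this logic would live inside `on_load_checkpoint`:

```python
import logging

def patch_state_dict(model_state, checkpoint_state):
    """Backfill keys the model gained after the checkpoint was saved,
    logging each one so the silent adjustment becomes visible to the user."""
    for key, value in model_state.items():
        if key not in checkpoint_state:
            logging.warning("Adding missing key to checkpoint state_dict: %s", key)
            checkpoint_state[key] = value
    return checkpoint_state
```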

@@ -72,7 +78,7 @@ def edge_update_(self, graph: dgl.DGLGraph) -> Tensor:
"""Perform edge update.

Args:
graph: Input graph
graph: Input g

In the edge_update_ method, the argument description for graph is incomplete. It should clearly describe the expected input.

-            graph: Input g
+            graph: Input DGLGraph object representing the graph structure.


@@ -85,7 +91,7 @@
"""Perform node update.

Args:
graph: Input graph
graph: Input g
Similarly, in the node_update_ method, the argument description for graph is incomplete and should be clarified.

-            graph: Input g
+            graph: Input DGLGraph object representing the graph structure.

Suggested change
graph: Input g
graph: Input DGLGraph object representing the graph structure.

Comment on lines 122 to 132
if collate_fn is None:
if "forces" not in train_data.dataset.labels:
collate_fn = collate_fn_graph
else:
if "stresses" not in train_data.dataset.labels:
collate_fn = partial(collate_fn_efs, include_stress=False)
else:
if "magmoms" not in train_data.dataset.labels: # noqa: SIM108
collate_fn = collate_fn_efs
else:
collate_fn = collate_fn_efsm
The dynamic assignment of collate_fn in MGLDataLoader based on the presence of labels in the dataset is a smart approach to handle different types of data. However, this logic could become complex and hard to maintain as more conditions are added. Consider refactoring this into a separate function or using a more scalable approach to handle the assignment of collate_fn.
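A sketch of one such refactor, with placeholder collate functions standing in for matgl's real ones (`select_collate_fn` is a hypothetical name, not part of the codebase):

```python
from functools import partial

# Placeholders standing in for matgl's real collate functions.
def collate_fn_graph(batch): ...
def collate_fn_efs(batch, include_stress=True): ...
def collate_fn_efsm(batch): ...

def select_collate_fn(labels):
    """Pick a collate function from the label keys present in the dataset,
    mirroring the precedence of the nested conditionals above."""
    if "forces" not in labels:
        return collate_fn_graph
    if "stresses" not in labels:
        return partial(collate_fn_efs, include_stress=False)
    if "magmoms" not in labels:
        return collate_fn_efs
    return collate_fn_efsm
```

`MGLDataLoader` would then reduce to a single line such as `collate_fn = collate_fn or select_collate_fn(train_data.dataset.labels)`, and new label combinations get handled in one place.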


This model is a CHGNet universal potential trained from the Materials Project trajectory (MPtrj) dataset
that contains over 1.5 million structures with 89 elements.
This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
Replace "pytorch" with "PyTorch" to adhere to the correct capitalization of the framework's name.

- original pytorch implementation
+ original PyTorch implementation

Suggested change
This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
This Matgl implementation has slight modification from original PyTorch implementation by adding directed edge updates.

# Training dataset

MPtrj-2022.9: Materials Project trajectory dataset that contains GGA and GGA+U static and relaxation calculations.
- Train-Val-Test splitting with mp-id: 0.95 - 0.5
For numerical ranges, consider using an en dash (–) instead of a hyphen (-) for improved readability. For example, "0.95–0.5" instead of "0.95 - 0.5".

- 0.95 - 0.5
+ 0.95–0.5

Suggested change
- Train-Val-Test splitting with mp-id: 0.95 - 0.5
- Train-Val-Test splitting with mp-id: 0.95–0.5


This model is a CHGNet universal potential trained from the Materials Project trajectory (MPtrj) dataset
that contains over 1.5 million structures with 89 elements.
This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
Replace "pytorch" with "PyTorch" to adhere to the correct capitalization of the framework's name.

- original pytorch implementation
+ original PyTorch implementation

Suggested change
This Matgl implementation has slight modification from original pytorch implementation by adding directed edge updates.
This Matgl implementation has slight modification from original PyTorch implementation by adding directed edge updates.

# Training dataset

MPtrj-2022.9: Materials Project trajectory dataset that contains GGA and GGA+U static and relaxation calculations.
- Train-Val-Test splitting with mp-id: 0.9 - 0.5 - 0.5
For numerical ranges, consider using an en dash (–) instead of a hyphen (-) for improved readability. For example, "0.9–0.5–0.5" instead of "0.9 - 0.5 - 0.5".

- 0.9 - 0.5 - 0.5
+ 0.9–0.5–0.5

Suggested change
- Train-Val-Test splitting with mp-id: 0.9 - 0.5 - 0.5
- Train-Val-Test splitting with mp-id: 0.9–0.5–0.5

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Review Status

Configuration used: CodeRabbit UI

Commits Files that changed from the base of the PR and between 392cc03 and 4985006.
Files selected for processing (1)
  • src/matgl/graph/data.py (10 hunks)
Additional Context Used
Additional comments not posted (2)
src/matgl/graph/data.py (2)

69-84: The implementation of collate_fn_efsm assumes that all items in the batch have the same keys, which might not always be the case. Consider iterating over all items in labels to check for the presence of "stresses" more robustly.
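A more defensive check, sketched here with a hypothetical helper, would scan every item in the batch rather than assuming uniform keys:

```python
def batch_has_stresses(labels_list):
    """True only if every item in the batch carries a 'stresses' label."""
    return all("stresses" in labels for labels in labels_list)

# A mixed batch: the second item lacks stresses, so the batch is treated as not having them.
batch = [{"energies": -1.2, "stresses": [0.0] * 6}, {"energies": -0.8}]
print(batch_has_stresses(batch))  # False
```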


158-179: > 📝 NOTE

This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [141-176]

The documentation for MGLDataset is thorough and informative. However, there's a minor inconsistency in the documentation style. For parameters like directed_line_graph, the description starts on the same line as the parameter name, while for others like raw_dir, it starts on a new line. Consider standardizing the documentation style for consistency.

@shyuep

shyuep commented May 2, 2024

@kenko911 Is this ready to be merged? You were supposed to review and provide comments.

@kenko911

kenko911 commented May 2, 2024

Hi @shyuep, I already discussed this with Bowen and I am currently working on it. I will make a PR by the end of this week.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Out of diff range and nitpick comments (9)
src/matgl/graph/data.py (6)

25-27: Consider adding type hints for the return values of collate_fn_graph to improve code readability and maintainability.


Line range hint 46-66: The function collate_fn_pes handles multiple conditions and returns different types of data based on the flags. This could lead to maintenance challenges. Consider refactoring to separate functions or a class to handle these variations more cleanly.


Line range hint 81-109: The dynamic assignment of collate_fn based on dataset labels in MGLDataLoader is complex. Consider simplifying this logic or encapsulating it in a method to improve readability and maintainability.


Line range hint 132-167: The constructor of MGLDataset has a large number of parameters, which can be error-prone and hard to maintain. Consider using a data class or a configuration object to encapsulate these parameters.
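One way to tame the parameter list is a configuration object; a minimal sketch (`MGLDatasetConfig` and its fields are hypothetical, chosen only to echo a few of the real constructor arguments):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MGLDatasetConfig:
    """Bundles MGLDataset construction options so call sites stay short."""
    filename: str = "dgl_graph.bin"
    include_line_graph: bool = False
    directed_line_graph: bool = False
    threebody_cutoff: Optional[float] = None
    clear_processed: bool = False
    save_cache: bool = True

cfg = MGLDatasetConfig(include_line_graph=True, threebody_cutoff=4.0)
```

The constructor would then accept a single `config` argument, defaults live in one place, and adding an option no longer grows every signature.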


Line range hint 202-221: The process method in MGLDataset is quite long and handles multiple responsibilities. Consider breaking it down into smaller, more focused methods to improve readability and maintainability.


240-254: In the save method, the condition if self.save_cache is False could be more clearly written as if not self.save_cache.

src/matgl/utils/training.py (1)

Line range hint 267-517: The PotentialLightningModule class is well-structured but complex due to handling multiple types of data (energy, force, stress, magmom). Consider refactoring to separate concerns, possibly by creating specialized classes for each data type.

tests/graph/test_data.py (2)

75-75: Consider using more descriptive names for dataset instances to improve code readability.


130-130: Consider using more descriptive names for dataset instances to improve code readability.

]
}
],
"outputs": [],
Consider adding error handling for the data fetching operation to manage potential issues with the URL or data format changes.
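A sketch of what such guarding could look like using only the standard library (`fetch_dataset` is a hypothetical helper, not code from the notebook):

```python
import urllib.error
import urllib.request

def fetch_dataset(url, timeout=30):
    """Fetch raw bytes, turning network failures into a single clear error."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.read()
    except (urllib.error.URLError, TimeoutError) as exc:
        raise RuntimeError(f"Failed to fetch dataset from {url}: {exc}") from exc
```

Callers then see one exception type with the offending URL in the message, instead of a raw traceback from deep inside `urllib`.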

Comment on lines +24 to +30
dataset = MGLDataset(
structures=structures,
converter=cry_graph,
labels={"label": label},
clear_processed=True,
save_cache=False,
)
Consider refactoring to reduce duplication of dataset initialization parameters across tests.
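One hypothetical way to share the common constructor arguments across tests is a small factory of keyword defaults:

```python
def make_dataset_kwargs(structures, labels, **overrides):
    """Shared defaults for MGLDataset construction in tests; overrides win."""
    kwargs = {
        "structures": structures,
        "labels": labels,
        "clear_processed": True,
        "save_cache": False,
    }
    kwargs.update(overrides)
    return kwargs

# Each test then writes: dataset = MGLDataset(converter=cry_graph, **make_dataset_kwargs(...))
kw = make_dataset_kwargs(["LiFePO4"], {"label": [0.1]}, save_cache=True)
```

A pytest fixture wrapping the same dictionary would work equally well; the point is that the defaults are stated once.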

g1, lat1, state1, label1 = dataset[0]
assert label1["label"] == label[0]
assert g1.num_edges() == cry_graph.get_graph(LiFePO4)[0].num_edges()
assert g1.num_nodes() == cry_graph.get_graph(LiFePO4)[0].num_nodes()
assert np.allclose(lat1.detach().numpy(), structures[0].lattice.matrix)
self.teardown_class()
shutil.rmtree(f"{dataset.save_path}")
Add error handling for directory removal to prevent potential data loss.
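A minimal sketch of guarded cleanup (the helper name is hypothetical):

```python
import shutil

def remove_dir_safely(path):
    """Remove a directory tree, tolerating a path that is already gone."""
    try:
        shutil.rmtree(path)
    except FileNotFoundError:
        pass  # already cleaned up; nothing to do
    except OSError as exc:
        # Surface the problem without failing the whole test teardown.
        print(f"warning: could not remove {path}: {exc}")
```

The teardown line then becomes `remove_dir_safely(f"{dataset.save_path}")`, and a missing or locked directory no longer aborts the remaining assertions.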

@@ -148,6 +163,7 @@
assert np.shape(pes2["forces"])[0], 10
assert np.allclose(lat1.detach().numpy(), structures[0].lattice.matrix)
assert np.allclose(lat2.detach().numpy(), structures[1].lattice.matrix)
shutil.rmtree(f"{dataset.save_path}")
Add error handling for directory removal to prevent potential data loss.

@@ -184,8 +199,7 @@
assert g1.num_nodes() == cry_graph.get_graph(LiFePO4)[0].num_nodes()
assert g2.num_edges() == cry_graph.get_graph(BaNiO3)[0].num_edges()
assert g2.num_nodes() == cry_graph.get_graph(BaNiO3)[0].num_nodes()

self.teardown_class()
shutil.rmtree(f"{dataset.save_path}")
Add error handling for directory removal to prevent potential data loss.

@kenko911

kenko911 commented May 6, 2024

Hi @shyuep, I have restructured some MatGL modules and added more unit tests for the CHGNet implementation. I discussed this with Bowen and we are both happy with the current changes. Only 3 unit tests fail, due to the addition and refactoring of some variables in the Potential class: the unit test for loading pretrained PES models in MatGL fails when loading the previous model version from e.g. /home/runner/.cache/matgl/M3GNet-MP-2021.2.8-DIRECT-PES/model.json. I also updated the model version of the Potential class from 2 to 3, and I believe this will be fixed automatically once we merge the changes into the main branch. Please have a look and let me know if there are any problems. Thanks!!

@shyuep shyuep merged commit 3d94dd4 into materialsvirtuallab:main May 6, 2024
2 of 3 checks passed