Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate paddle backend in devel branch #4157

Closed
Closed
Show file tree
Hide file tree
Changes from 65 commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
59b9af5
Add paddle backend code(WIP)
HydrogenSulfate Sep 3, 2024
7f32618
update runnable code with water/se_e2_a
HydrogenSulfate Sep 4, 2024
217bf36
update correct water/se_e2_a code
HydrogenSulfate Sep 4, 2024
0ad720d
fix extra state
HydrogenSulfate Sep 5, 2024
1d7b0d1
Fix typo and bugs
HydrogenSulfate Sep 6, 2024
f6eeef6
fix concat
HydrogenSulfate Sep 6, 2024
ec021f7
update fix code
HydrogenSulfate Sep 7, 2024
55f71f6
add normalize composite impl
HydrogenSulfate Sep 9, 2024
ebdbed2
use prim function instead of vanilla function, supporting
HydrogenSulfate Sep 10, 2024
e0aeb73
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 10, 2024
0896c90
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 13, 2024
2e79d68
update inference code(WIP)
HydrogenSulfate Sep 14, 2024
482d588
update CMAKE code(WIP)
HydrogenSulfate Sep 18, 2024
a3c4663
update CMAKE
HydrogenSulfate Sep 18, 2024
2b4832a
add build script
HydrogenSulfate Sep 18, 2024
f1100e4
Update water/se_e2_a + LAMMPS code
HydrogenSulfate Sep 23, 2024
3068869
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 23, 2024
68a2d62
fix get_pd_version
HydrogenSulfate Sep 23, 2024
ff7e0ef
fix read_env.py
HydrogenSulfate Sep 23, 2024
023ba53
fix suffix
HydrogenSulfate Sep 23, 2024
8a1834f
fix pd/cxx_op.py
HydrogenSulfate Sep 23, 2024
396bd54
fix main.py
HydrogenSulfate Sep 23, 2024
66734bc
fix get_item_paddle
HydrogenSulfate Sep 23, 2024
ba02ae8
restore in.lammps
HydrogenSulfate Sep 23, 2024
40157dd
restore pyproject.toml
HydrogenSulfate Sep 23, 2024
8d53aec
simplify CMAKE
HydrogenSulfate Sep 23, 2024
e39d466
restore c_api.cc
HydrogenSulfate Sep 23, 2024
b97571e
fix bugs
HydrogenSulfate Sep 24, 2024
50092c6
change pt -> pd
HydrogenSulfate Sep 24, 2024
e02dd11
refactor DeepPotPD.cc
HydrogenSulfate Sep 24, 2024
8343077
update commonPD.h
HydrogenSulfate Sep 24, 2024
09c54f3
remove boost
HydrogenSulfate Sep 24, 2024
bc854b2
refine code
HydrogenSulfate Sep 25, 2024
892fd80
update refined infer code
HydrogenSulfate Sep 26, 2024
04b9064
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 27, 2024
b3a6408
remove redundant code
HydrogenSulfate Sep 27, 2024
15a7e75
refine docstring of get_buffer
HydrogenSulfate Sep 27, 2024
c792563
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Sep 27, 2024
72241ea
update pd version code
HydrogenSulfate Sep 27, 2024
8694476
restore non related files
HydrogenSulfate Sep 27, 2024
e246a34
add paddle to related docs
HydrogenSulfate Sep 27, 2024
5ee8bcf
optimize cmake paddle macro name
HydrogenSulfate Sep 27, 2024
e3c1ceb
update parallel training with paddle backend
HydrogenSulfate Sep 27, 2024
49ba5a5
use 0-D Tensor as buffer shape
HydrogenSulfate Sep 27, 2024
f1cae59
support DCU(rocm)
HydrogenSulfate Sep 29, 2024
a83fb63
simplify compile code
HydrogenSulfate Sep 29, 2024
f7f64b1
fix code
HydrogenSulfate Oct 9, 2024
d5a313e
remove float() for already supporting 0-D scalar __format__
HydrogenSulfate Oct 14, 2024
299548a
polish flag via paddle.set_flags
HydrogenSulfate Oct 14, 2024
97828b3
restore make_model.py
HydrogenSulfate Oct 14, 2024
d67e27c
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Oct 14, 2024
87069e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2024
d6e3fdb
update document and fix bugs in code
HydrogenSulfate Oct 15, 2024
8004a52
Merge branch 'add_paddle_backend' of https://github.com/HydrogenSulfa…
HydrogenSulfate Oct 15, 2024
8e951cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
3a0f700
simplify code and only support json and pd
HydrogenSulfate Oct 15, 2024
73c0d17
Merge branch 'add_paddle_backend' of https://github.com/HydrogenSulfa…
HydrogenSulfate Oct 15, 2024
8a59a53
Fix get_generator
HydrogenSulfate Oct 15, 2024
7821997
set default NUM_WORKERS to 0
HydrogenSulfate Oct 15, 2024
ed51258
fix LOCAL_RANK
HydrogenSulfate Oct 15, 2024
13c7f55
update CMAKE and remove build_cc_pd.sh
HydrogenSulfate Oct 17, 2024
e013860
remove paddle.jit.export
HydrogenSulfate Oct 17, 2024
4c4568f
refine code
HydrogenSulfate Oct 17, 2024
7525dac
fix neighborstat bug: gt -> ge
HydrogenSulfate Oct 18, 2024
312a3ef
fix bugs and add unitest of paddle backend
HydrogenSulfate Oct 20, 2024
5e4edd7
fix part of codes
HydrogenSulfate Oct 22, 2024
0146f24
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Oct 22, 2024
834d512
update document and test_python yaml
HydrogenSulfate Oct 22, 2024
9e90416
remove typeAlias for not available in python3.9
HydrogenSulfate Oct 22, 2024
cfacca3
update repformers.py
HydrogenSulfate Oct 22, 2024
d492397
remove old_impl code and redundant init.py
HydrogenSulfate Oct 24, 2024
de24e27
correct paddlepaddle requirement string
HydrogenSulfate Oct 24, 2024
264286f
update old code
HydrogenSulfate Oct 24, 2024
fd6aff0
update test C++ interface
HydrogenSulfate Oct 28, 2024
6651e87
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Oct 28, 2024
39ca3b7
update pip index
HydrogenSulfate Oct 28, 2024
2de0d2b
remoe cvt.py
HydrogenSulfate Oct 28, 2024
11d0344
update files
HydrogenSulfate Oct 28, 2024
3650214
upload new files
HydrogenSulfate Oct 28, 2024
754b948
update code
HydrogenSulfate Oct 28, 2024
f007fb4
upload missing files
HydrogenSulfate Oct 28, 2024
da4fb97
add soft link
HydrogenSulfate Oct 28, 2024
f8a4279
fix ci
HydrogenSulfate Oct 28, 2024
1643c6c
refine nlist
HydrogenSulfate Oct 28, 2024
0d0662f
fix req
HydrogenSulfate Oct 28, 2024
8b9ee50
update index-strategy
HydrogenSulfate Oct 28, 2024
d10a3f7
add auto download for paddle_inference.tgz
HydrogenSulfate Oct 28, 2024
f6253cf
skip 3 allclose temporarily
HydrogenSulfate Oct 28, 2024
690dec2
fix mlp
HydrogenSulfate Oct 29, 2024
ec47dd8
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Oct 29, 2024
03084d2
fix test_nlist.py
HydrogenSulfate Oct 29, 2024
10e4a0d
set device in env.py
HydrogenSulfate Oct 29, 2024
a1adc8a
skip some tests temporarily
HydrogenSulfate Oct 29, 2024
5edc0ae
quite download and decompression
HydrogenSulfate Oct 29, 2024
6d13376
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Oct 29, 2024
9b1f322
update project.toml
HydrogenSulfate Oct 29, 2024
2b34756
use np.testing.assert_allclose instead of paddle.allclose for more ac…
HydrogenSulfate Oct 29, 2024
dad72c8
use np.testing.assert_allclose instead of paddle.allclose
HydrogenSulfate Oct 29, 2024
cbc9c65
reduce prec from 1e-10 to 1e-9 for test_rot
HydrogenSulfate Oct 29, 2024
afd4746
Merge branch 'devel' into add_paddle_backend
HydrogenSulfate Oct 30, 2024
7b2476f
support LKF optimizer
HydrogenSulfate Oct 31, 2024
26047e9
fix condition block dtype mismatch in jit.save and enable 2 unitest
HydrogenSulfate Oct 31, 2024
d03702a
fix bugs and enable more pd unitests
HydrogenSulfate Nov 1, 2024
c611955
fix the last 2 files
HydrogenSulfate Nov 1, 2024
3b27c49
rename aux to decomp
HydrogenSulfate Nov 1, 2024
541dae6
polish decomp.py
HydrogenSulfate Nov 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions .pre-commit-config.yaml
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change to this file should be reverted

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I annotate mirrors-prettier and mirrors-bibtex-tidy because that can not be installed in my development env.
image

I guess this file can be restored at the last commit in this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.

Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ repos:
hooks:
- id: clang-format
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$)
# markdown, yaml, CSS, javascript
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: prettier
types_or: [markdown, yaml, css]
# workflow files cannot be modified by pre-commit.ci
exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
# # markdown, yaml, CSS, javascript
# - repo: https://github.com/pre-commit/mirrors-prettier
# rev: v4.0.0-alpha.8
# hooks:
# - id: prettier
# types_or: [markdown, yaml, css]
# # workflow files cannot be modified by pre-commit.ci
# exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
# Shell
- repo: https://github.com/scop/pre-commit-shfmt
rev: v3.9.0-1
Expand All @@ -75,25 +75,25 @@ repos:
hooks:
- id: cmake-format
#- id: cmake-lint
- repo: https://github.com/njzjz/mirrors-bibtex-tidy
rev: v1.13.0
hooks:
- id: bibtex-tidy
args:
- --curly
- --numeric
- --align=13
- --blank-lines
# disable sort: the order of keys and fields has explict meanings
#- --sort=key
- --duplicates=key,doi,citation,abstract
- --merge=combine
#- --sort-fields
#- --strip-comments
- --trailing-commas
- --encode-urls
- --remove-empty-fields
- --wrap=80
# - repo: https://github.com/njzjz/mirrors-bibtex-tidy
# rev: v1.13.0
# hooks:
# - id: bibtex-tidy
# args:
# - --curly
# - --numeric
# - --align=13
# - --blank-lines
# # disable sort: the order of keys and fields has explict meanings
# #- --sort=key
# - --duplicates=key,doi,citation,abstract
# - --merge=combine
# #- --sort-fields
# #- --strip-comments
# - --trailing-commas
# - --encode-urls
# - --remove-empty-fields
# - --wrap=80
# license header
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.5
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ For more information, check the [documentation](https://deepmd.readthedocs.io/).

### Highlighted features

- **interfaced with multiple backends**, including TensorFlow and PyTorch, the most popular deep learning frameworks, making the training process highly automatic and efficient.
- **interfaced with multiple backends**, including TensorFlow, PyTorch and Paddle, the most popular deep learning frameworks, making the training process highly automatic and efficient.
- **interfaced with high-performance classical MD and quantum (path-integral) MD packages**, including LAMMPS, i-PI, AMBER, CP2K, GROMACS, OpenMM, and ABUCUS.
- **implements the Deep Potential series models**, which have been successfully applied to finite and extended systems, including organic molecules, metals, semiconductors, insulators, etc.
- **implements MPI and GPU supports**, making it highly efficient for high-performance parallel and distributed computing.
Expand Down Expand Up @@ -72,7 +72,7 @@ See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all fea

#### v3

- Multiple backends supported. Add a PyTorch backend.
- Multiple backends supported. Add PyTorch and Paddle backend.
- The DPA-2 model.

## Install and use DeePMD-kit
Expand Down
5 changes: 5 additions & 0 deletions backend/dp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from scikit_build_core import build as _orig

from .find_paddle import (
find_paddle,
)
from .find_pytorch import (
find_pytorch,
)
Expand Down Expand Up @@ -43,6 +46,7 @@ def get_requires_for_build_wheel(
_orig.get_requires_for_build_wheel(config_settings)
+ find_tensorflow()[1]
+ find_pytorch()[1]
+ find_paddle()[1]
)


Expand All @@ -53,4 +57,5 @@ def get_requires_for_build_editable(
_orig.get_requires_for_build_editable(config_settings)
+ find_tensorflow()[1]
+ find_pytorch()[1]
+ find_paddle()[1]
)
6 changes: 5 additions & 1 deletion backend/dynamic_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
Optional,
)

from .find_paddle import (
get_pd_requirement,
)
from .find_pytorch import (
get_pt_requirement,
)
Expand Down Expand Up @@ -34,7 +37,7 @@ def dynamic_metadata(
settings: Optional[dict[str, object]] = None,
):
assert field in ["optional-dependencies", "entry-points", "scripts"]
_, _, find_libpython_requires, extra_scripts, tf_version, pt_version = (
_, _, find_libpython_requires, extra_scripts, tf_version, pt_version, pd_version = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

⚠️ Potential issue

Unused pd_version Variable Detected

The variable pd_version is unpacked in backend/dynamic_metadata.py at line 39 but is not used within the dynamic_metadata function. This may lead to unused variable warnings and could indicate incomplete implementation.

  • If pd_version is intended to be used:
    • Update the function to utilize pd_version where necessary.
    • For example, incorporate pd_version in the optional-dependencies section:
      elif field == "optional-dependencies":
          optional_dependencies = pyproject["tool"]["deepmd_build_backend"][
              "optional-dependencies"
          ]
          optional_dependencies["lmp"].extend(find_libpython_requires)
          optional_dependencies["ipi"].extend(find_libpython_requires)
          return {
              **optional_dependencies,
              **get_tf_requirement(tf_version),
              **get_pt_requirement(pt_version),
              **get_pd_requirement(pd_version),  # Add this line
          }
  • If pd_version is not needed:
    • Remove pd_version from the unpacking to avoid unused variable warnings:
      _, _, find_libpython_requires, extra_scripts, tf_version, pt_version = (
Analysis chain

Address inconsistency and unused variable

  1. There's an inconsistency between the AI-generated summary and the actual code. The summary mentions a change in the function signature, but this is not reflected in the provided code. Please clarify if there are missing changes or if the summary is incorrect.

  2. The addition of pd_version suggests the introduction of Paddle support, which aligns with the PR objectives. However, pd_version is currently not used in the rest of the function.

To address these issues:

  1. If pd_version is intended to be a parameter, update the function signature:
def dynamic_metadata(
    field: str,
    settings: Optional[Dict[str, object]] = None,
    pd_version: Optional[str] = None,
):
  1. Implement the usage of pd_version in the function body, possibly in the optional-dependencies section:
elif field == "optional-dependencies":
    optional_dependencies = pyproject["tool"]["deepmd_build_backend"][
        "optional-dependencies"
    ]
    optional_dependencies["lmp"].extend(find_libpython_requires)
    optional_dependencies["ipi"].extend(find_libpython_requires)
    return {
        **optional_dependencies,
        **get_tf_requirement(tf_version),
        **get_pt_requirement(pt_version),
        **get_pd_requirement(pd_version),  # Add this line
    }
  1. If pd_version is not needed, remove it from the unpacking to avoid unused variable warnings.

To ensure pd_version is used consistently across the codebase, run the following script:

This will help verify the usage of pd_version and identify any Paddle-related functions that might need to be implemented or updated.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for occurrences of pd_version and potential Paddle-related functions

# Search for pd_version
echo "Occurrences of pd_version:"
rg 'pd_version' --type python

# Search for potential Paddle-related functions
echo "\nPotential Paddle-related functions:"
rg 'get_pd_requirement|paddle' --type python

Length of output: 286


Script:

#!/bin/bash
# Description: Check for occurrences of pd_version and potential Paddle-related functions

# Search for pd_version in Python files
echo "Occurrences of pd_version:"
rg 'pd_version' --type py

# Search for potential Paddle-related functions in Python files
echo "\nPotential Paddle-related functions:"
rg 'get_pd_requirement|paddle' --type py

Length of output: 167111

get_argument_from_env()
)
with Path("pyproject.toml").open("rb") as f:
Expand All @@ -55,4 +58,5 @@ def dynamic_metadata(
**optional_dependencies,
**get_tf_requirement(tf_version),
**get_pt_requirement(pt_version),
**get_pd_requirement(pd_version),
}
143 changes: 143 additions & 0 deletions backend/find_paddle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
import os
import site
from functools import (
lru_cache,
)
from importlib.machinery import (
FileFinder,
)
from importlib.util import (
find_spec,
)
from pathlib import (
Path,
)
from sysconfig import (
get_path,
)
from typing import (
Optional,
Union,
)

from packaging.version import (
Version,
)


@lru_cache
def find_paddle() -> tuple[Optional[str], list[str]]:
"""Find PaddlePadle library.

Tries to find PaddlePadle in the order of:

1. Environment variable `PADDLE_ROOT` if set
2. The current Python environment.
3. user site packages directory if enabled
4. system site packages directory (purelib)

Considering the default PaddlePadle package still uses old CXX11 ABI, we
cannot install it automatically.

Returns
-------
str, optional
PaddlePadle library path if found.
list of str
Paddle requirement if not found. Empty if found.
"""
if os.environ.get("DP_ENABLE_PADDLE", "0") == "0":
return None, []
requires = []
pd_spec = None

if (pd_spec is None or not pd_spec) and os.environ.get("PADDLE_ROOT") is not None:
site_packages = Path(os.environ.get("PADDLE_ROOT")).parent.absolute()
pd_spec = FileFinder(str(site_packages)).find_spec("paddle")

# get paddle spec
# note: isolated build will not work for backend
if pd_spec is None or not pd_spec:
pd_spec = find_spec("paddle")

if not pd_spec and site.ENABLE_USER_SITE:
# first search TF from user site-packages before global site-packages
site_packages = site.getusersitepackages()
if site_packages:
pd_spec = FileFinder(site_packages).find_spec("paddle")

if not pd_spec:
# purelib gets site-packages path
site_packages = get_path("purelib")
if site_packages:
pd_spec = FileFinder(site_packages).find_spec("paddle")

# get install dir from spec
try:
pd_install_dir = pd_spec.submodule_search_locations[0] # type: ignore
# AttributeError if ft_spec is None
# TypeError if submodule_search_locations are None
# IndexError if submodule_search_locations is an empty list
except (AttributeError, TypeError, IndexError):
pd_install_dir = None
requires.extend(get_pd_requirement()["paddle"])
return pd_install_dir, requires


@lru_cache
def get_pd_requirement(pd_version: str = "") -> dict:
"""Get PaddlePadle requirement when Paddle is not installed.

If pd_version is not given and the environment variable `PADDLE_VERSION` is set, use it as the requirement.

Parameters
----------
pd_version : str, optional
Paddle version

Returns
-------
dict
PaddlePadle requirement.
"""
if pd_version is None:
return {"paddle": []}
if pd_version == "":
pd_version = os.environ.get("PADDLE_VERSION", "")
Comment on lines +86 to +104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Inconsistent handling of pd_version and type annotation

In get_pd_requirement, the type annotation specifies pd_version: str = "", but the code checks if pd_version is None(line 105). Since the default value is"", pd_versionshould not beNoneunless explicitly passed as such. Consider either updating the type annotation toOptional[str]ifNoneis acceptable, or removing theif pd_version is Nonecheck ifpd_version` should always be a string.

Apply this diff to address the issue:

If pd_version can be None, update the type annotation:

-def get_pd_requirement(pd_version: str = "") -> dict:
+def get_pd_requirement(pd_version: Optional[str] = "") -> dict:

Alternatively, if pd_version should always be a string, remove the unnecessary check:

-    if pd_version is None:
-        return {"paddle": []}

Committable suggestion was skipped due to low confidence.


return {
"paddle": [
# uv has different local version behaviors, i.e. `==2.3.1` cannot match `==2.3.1+cpu`
# https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#local-version-identifiers
# luckily, .* (prefix matching) defined in PEP 440 can match any local version
# https://peps.python.org/pep-0440/#version-matching
f"paddle=={Version(pd_version).base_version}.*"
if pd_version != ""
else "paddlepaddle-gpu>=3.0.0b1",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure consistent package naming in requirements

In get_pd_requirement, the package name switches between paddle and paddlepaddle-gpu. This inconsistency may lead to installation issues. Consider standardizing the package name to ensure reliable dependency management.

Apply this diff to standardize the package name:

             f"paddle=={Version(pd_version).base_version}.*"
             if pd_version != ""
-            else "paddlepaddle-gpu>=3.0.0b1",
+            else "paddle>=3.0.0b1",
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
f"paddle=={Version(pd_version).base_version}.*"
if pd_version != ""
else "paddlepaddle-gpu>=3.0.0b1",
f"paddle=={Version(pd_version).base_version}.*"
if pd_version != ""
else "paddle>=3.0.0b1",

],
}


@lru_cache
def get_pd_version(pd_path: Optional[Union[str, Path]]) -> str:
"""Get Paddle version from a Paddle Python library path.

Parameters
----------
pd_path : str or Path
pd Python library path

Returns
-------
str
version
"""
if pd_path is None or pd_path == "":
return ""
version_file = Path(pd_path) / "version" / "__init__.py"
spec = importlib.util.spec_from_file_location("paddle.version", version_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
Comment on lines +130 to +132
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling for potential NoneType exceptions in get_pd_version

The function get_pd_version may raise an exception if spec or spec.loader is None. To prevent this, add checks to handle cases where the version file does not exist or is invalid.

Apply this diff to enhance error handling:

     version_file = Path(pd_path) / "version" / "__init__.py"
+    if not version_file.exists():
+        return ""
     spec = importlib.util.spec_from_file_location("paddle.version", version_file)
+    if spec is None or spec.loader is None:
+        return ""
     module = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(module)
     return module.full_version
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
spec = importlib.util.spec_from_file_location("paddle.version", version_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version_file = Path(pd_path) / "version" / "__init__.py"
if not version_file.exists():
return ""
spec = importlib.util.spec_from_file_location("paddle.version", version_file)
if spec is None or spec.loader is None:
return ""
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.full_version

return module.full_version
24 changes: 22 additions & 2 deletions backend/read_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
Version,
)

from .find_paddle import (
find_paddle,
get_pd_version,
)
from .find_pytorch import (
find_pytorch,
get_pt_version,
Expand All @@ -21,7 +25,7 @@


@lru_cache
def get_argument_from_env() -> tuple[str, list, list, dict, str, str]:
def get_argument_from_env() -> tuple[str, list, list, dict, str, str, str]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider refactoring return values into a data structure for better maintainability

As the number of return values from get_argument_from_env increases, it may be beneficial to encapsulate them into a data class or NamedTuple. This improves code readability and maintainability by providing named attributes instead of relying on tuple indices.

You can define a NamedTuple to represent the return values:

from typing import NamedTuple

class BuildArguments(NamedTuple):
    cmake_minimum_required_version: str
    cmake_args: list
    find_libpython_requires: list
    extra_scripts: dict
    tf_version: str
    pt_version: str
    pd_version: str

Update the function signature:

-def get_argument_from_env() -> tuple[str, list, list, dict, str, str, str]:
+def get_argument_from_env() -> BuildArguments:

Update the return statement:

-return (
+return BuildArguments(
    cmake_minimum_required_version,
    cmake_args,
    find_libpython_requires,
    extra_scripts,
    tf_version,
    pt_version,
    pd_version,
)

This approach enhances code clarity when unpacking values:

-def set_scikit_build_env():
+def set_scikit_build_env():
    """Set scikit-build environment variables before executing scikit-build."""
-    cmake_minimum_required_version, cmake_args, _, _, _, _, _ = get_argument_from_env()
+    build_args = get_argument_from_env()
    os.environ["SKBUILD_CMAKE_MINIMUM_VERSION"] = build_args.cmake_minimum_required_version
    os.environ["SKBUILD_CMAKE_ARGS"] = ";".join(build_args.cmake_args)

Also applies to: 47-48, 150-150

"""Get the arguments from environment variables.

The environment variables are assumed to be not changed during the build.
Expand All @@ -40,6 +44,8 @@ def get_argument_from_env() -> tuple[str, list, list, dict, str, str]:
The TensorFlow version.
str
The PyTorch version.
str
The Paddle version.
"""
cmake_args = []
extra_scripts = {}
Expand Down Expand Up @@ -117,6 +123,19 @@ def get_argument_from_env() -> tuple[str, list, list, dict, str, str]:
cmake_args.append("-DENABLE_PYTORCH=OFF")
pt_version = None

if os.environ.get("DP_ENABLE_PADDLE", "0") == "1":
pd_install_dir, _ = find_paddle()
pd_version = get_pd_version(pd_install_dir)
cmake_args.extend(
[
"-DENABLE_PADDLE=ON",
f"-DCMAKE_PREFIX_PATH={pd_install_dir}",
HydrogenSulfate marked this conversation as resolved.
Show resolved Hide resolved
]
)
else:
cmake_args.append("-DENABLE_PADDLE=OFF")
pd_version = None

cmake_args = [
"-DBUILD_PY_IF:BOOL=TRUE",
*cmake_args,
Expand All @@ -128,11 +147,12 @@ def get_argument_from_env() -> tuple[str, list, list, dict, str, str]:
extra_scripts,
tf_version,
pt_version,
pd_version,
)


def set_scikit_build_env():
"""Set scikit-build environment variables before executing scikit-build."""
cmake_minimum_required_version, cmake_args, _, _, _, _ = get_argument_from_env()
cmake_minimum_required_version, cmake_args, _, _, _, _, _ = get_argument_from_env()
os.environ["SKBUILD_CMAKE_MINIMUM_VERSION"] = cmake_minimum_required_version
os.environ["SKBUILD_CMAKE_ARGS"] = ";".join(cmake_args)
Loading
Loading