Skip to content

Commit 9441003

Browse files
Enable Training Pybindings in OSS (#8073)
* Enable Training Pybindings in OSS * changes * changes * changes * changes * lint * unbreak weird apple failure
1 parent 6b58e2e commit 9441003

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

CMakeLists.txt

+36
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ cmake_dependent_option(
240240
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
241241
)
242242

243+
244+
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
245+
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
246+
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
247+
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
248+
endif()
249+
243250
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
244251
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
245252
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)
@@ -791,6 +798,35 @@ if(EXECUTORCH_BUILD_PYBIND)
791798
install(TARGETS portable_lib
792799
LIBRARY DESTINATION executorch/extension/pybindings
793800
)
801+
802+
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
803+
804+
set(_pybind_training_dep_libs
805+
${TORCH_PYTHON_LIBRARY}
806+
etdump
807+
executorch
808+
util
809+
torch
810+
extension_training
811+
)
812+
813+
if(EXECUTORCH_BUILD_XNNPACK)
814+
# need to explicitly specify XNNPACK and microkernels-prod
815+
# here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu
816+
list(APPEND _pybind_training_dep_libs xnnpack_backend XNNPACK microkernels-prod)
817+
endif()
818+
819+
# pybind training
820+
pybind11_add_module(_training_lib SHARED extension/training/pybindings/_training_lib.cpp)
821+
822+
target_include_directories(_training_lib PRIVATE ${TORCH_INCLUDE_DIRS})
823+
target_compile_options(_training_lib PUBLIC ${_pybind_compile_options})
824+
target_link_libraries(_training_lib PRIVATE ${_pybind_training_dep_libs})
825+
826+
install(TARGETS _training_lib
827+
LIBRARY DESTINATION executorch/extension/training/pybindings
828+
)
829+
endif()
794830
endif()
795831

796832
if(EXECUTORCH_BUILD_KERNELS_CUSTOM)

install_executorch.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def clean():
3232
print("Done cleaning build artifacts.")
3333

3434

35-
VALID_PYBINDS = ["coreml", "mps", "xnnpack"]
35+
VALID_PYBINDS = ["coreml", "mps", "xnnpack", "training"]
3636

3737

3838
def main(args):
@@ -78,8 +78,12 @@ def main(args):
7878
raise Exception(
7979
f"Unrecognized pybind argument {pybind_arg}; valid options are: {', '.join(VALID_PYBINDS)}"
8080
)
81+
if pybind_arg == "training":
82+
CMAKE_ARGS += " -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON"
83+
os.environ["EXECUTORCH_BUILD_TRAINING"] = "ON"
84+
else:
85+
CMAKE_ARGS += f" -DEXECUTORCH_BUILD_{pybind_arg.upper()}=ON"
8186
EXECUTORCH_BUILD_PYBIND = "ON"
82-
CMAKE_ARGS += f" -DEXECUTORCH_BUILD_{pybind_arg.upper()}=ON"
8387

8488
if args.clean:
8589
clean()

setup.py

+17
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def _is_env_enabled(env_var: str, default: bool = False) -> bool:
8686
def pybindings(cls) -> bool:
8787
return cls._is_env_enabled("EXECUTORCH_BUILD_PYBIND", default=False)
8888

89+
@classmethod
90+
def training(cls) -> bool:
91+
return cls._is_env_enabled("EXECUTORCH_BUILD_TRAINING", default=False)
92+
8993
@classmethod
9094
def llama_custom_ops(cls) -> bool:
9195
return cls._is_env_enabled("EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT", default=True)
@@ -575,6 +579,11 @@ def run(self):
575579
"-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON", # add quantized ops to pybindings.
576580
"-DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON",
577581
]
582+
if ShouldBuild.training():
583+
cmake_args += [
584+
"-DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON",
585+
]
586+
build_args += ["--target", "_training_lib"]
578587
build_args += ["--target", "portable_lib"]
579588
# To link backends into the portable_lib target, callers should
580589
# add entries like `-DEXECUTORCH_BUILD_XNNPACK=ON` to the CMAKE_ARGS
@@ -677,6 +686,14 @@ def get_ext_modules() -> List[Extension]:
677686
"_portable_lib.*", "executorch.extension.pybindings._portable_lib"
678687
)
679688
)
689+
if ShouldBuild.training():
690+
ext_modules.append(
691+
# Install the prebuilt pybindings extension wrapper for training
692+
BuiltExtension(
693+
"_training_lib.*",
694+
"executorch.extension.training.pybindings._training_lib",
695+
)
696+
)
680697
if ShouldBuild.llama_custom_ops():
681698
ext_modules.append(
682699
BuiltFile(

0 commit comments

Comments
 (0)