Skip to content

Commit f83237b

Browse files
committed
wip: use torch from a wheel
not recorded: external step of extracting a wheel * unzip wheel * create dist dir with wheel see extract.sh a couple folders up
1 parent 6ed3ca1 commit f83237b

File tree

5 files changed

+187
-33
lines changed

5 files changed

+187
-33
lines changed

WORKSPACE

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,29 @@ python_configure(
3535
################################ PyTorch Setup ################################
3636

3737
load("//bazel:dependencies.bzl", "PYTORCH_LOCAL_DIR")
38+
load("//bazel:torch_repo.bzl", "torch_repo")
3839

39-
new_local_repository(
40+
torch_repo(
4041
name = "torch",
41-
build_file = "//bazel:torch.BUILD",
42-
path = PYTORCH_LOCAL_DIR,
42+
path = "/usr/local/google/home/rlevasseur/p/torch_repo/torch-2.8.0.dev20250609+cpu-cp311-cp311-manylinux_2_28_x86_64.whl",
43+
urls = [
44+
"https://download.pytorch.org/whl/nightly/cpu/torch-2.8.0.dev20250609%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl",
45+
],
4346
)
4447

48+
##new_local_repository(
49+
## name = "torch",
50+
## ##build_file = "//bazel:torch.BUILD",
51+
## ##path = PYTORCH_LOCAL_DIR,
52+
## build_file = "//bazel:torchnew.BUILD",
53+
## path = "/usr/local/google/home/rlevasseur/p/torch_repo/torch-2.8.0.dev20250609+cpu-cp311-cp311-manylinux_2_28_x86_64",
54+
##)
55+
4556
############################# OpenXLA Setup ###############################
4657

4758
# To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to
4859
# the openxla git commit hash and note the date of the commit.
49-
xla_hash = 'd4576615b3bd3644567da60202faf19b485b52f9' # Committed on 2025-06-05.
60+
xla_hash = "d4576615b3bd3644567da60202faf19b485b52f9" # Committed on 2025-06-05.
5061

5162
http_archive(
5263
name = "xla",
@@ -66,8 +77,6 @@ http_archive(
6677
],
6778
)
6879

69-
70-
7180
# For development, one often wants to make changes to the OpenXLA repository as well
7281
# as the PyTorch/XLA repository. You can override the pinned repository above with a
7382
# local checkout by either:
@@ -89,14 +98,14 @@ python_init_rules()
8998
load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
9099

91100
python_init_repositories(
101+
default_python_version = "system",
102+
local_wheel_workspaces = ["@torch//:WORKSPACE"],
92103
requirements = {
93104
"3.8": "//:requirements_lock_3_8.txt",
94105
"3.9": "//:requirements_lock_3_9.txt",
95106
"3.10": "//:requirements_lock_3_10.txt",
96107
"3.11": "//:requirements_lock_3_11.txt",
97108
},
98-
local_wheel_workspaces = ["@torch//:WORKSPACE"],
99-
default_python_version = "system",
100109
)
101110

102111
load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
@@ -111,8 +120,6 @@ load("@pypi//:requirements.bzl", "install_deps")
111120

112121
install_deps()
113122

114-
115-
116123
# Initialize OpenXLA's external dependencies.
117124
load("@xla//:workspace4.bzl", "xla_workspace4")
118125

@@ -134,7 +141,6 @@ load("@xla//:workspace0.bzl", "xla_workspace0")
134141

135142
xla_workspace0()
136143

137-
138144
load(
139145
"@xla//third_party/gpus:cuda_configure.bzl",
140146
"cuda_configure",

bazel/torch.BUILD

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,36 +23,44 @@ cc_library(
2323
filegroup(
2424
name = "torchgen_deps",
2525
srcs = [
26-
"aten/src/ATen/native/native_functions.yaml",
27-
"aten/src/ATen/native/tags.yaml",
28-
"aten/src/ATen/native/ts_native_functions.yaml",
29-
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
30-
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
31-
"aten/src/ATen/templates/LazyIr.h",
32-
"aten/src/ATen/templates/LazyNonNativeIr.h",
33-
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
34-
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
35-
"torch/csrc/lazy/core/shape_inference.h",
36-
"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
26+
# torchgen/packaged/ instead of aten/src
27+
"torchgen/packaged/ATen/native/native_functions.yaml",
28+
"torchgen/packaged/ATen/native/tags.yaml",
29+
##"torchgen/packaged/ATen/native/ts_native_functions.yaml",
30+
"torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp",
31+
"torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h",
32+
"torchgen/packaged/ATen/templates/LazyIr.h",
33+
"torchgen/packaged/ATen/templates/LazyNonNativeIr.h",
34+
"torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini",
35+
"torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp",
36+
# Add torch/include prefix
37+
"torch/include/torch/csrc/lazy/core/shape_inference.h",
38+
##"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
3739
],
3840
)
3941

40-
cc_import(
42+
# Changed to cc_library from cc_import
43+
44+
cc_library(
4145
name = "libtorch",
42-
shared_library = "build/lib/libtorch.so",
46+
srcs = ["torch/lib/libtorch.so"],
4347
)
4448

45-
cc_import(
49+
cc_library(
4650
name = "libtorch_cpu",
47-
shared_library = "build/lib/libtorch_cpu.so",
51+
srcs = ["torch/lib/libtorch_cpu.so"],
4852
)
4953

50-
cc_import(
54+
cc_library(
5155
name = "libtorch_python",
52-
shared_library = "build/lib/libtorch_python.so",
56+
srcs = [
57+
# Added this
58+
"torch/lib/libshm.so",
59+
"torch/lib/libtorch_python.so",
60+
],
5361
)
5462

55-
cc_import(
63+
cc_library(
5664
name = "libc10",
57-
shared_library = "build/lib/libc10.so",
65+
srcs = ["torch/lib/libc10.so"],
5866
)

bazel/torch_repo.bzl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#load("@bazel_skylib//lib:paths.bzl", "paths")
2+
3+
_BUILD_TEMPLATE = """
4+
5+
load("@//bazel:torch_whl_targets.bzl", "define_torch_whl_targets")
6+
7+
package(
8+
default_visibility = [
9+
"//visibility:public",
10+
],
11+
)
12+
13+
define_torch_whl_targets()
14+
"""
15+
16+
def _basename(path):
17+
_, _, basename = path.rpartition("/")
18+
return basename
19+
20+
def _urldecode(s):
21+
# Starlark doesn't have any URL decode functions, so just approximate
22+
# one with the cases we see.
23+
return s.replace("%2B", "+")
24+
25+
def _torch_repo_impl(rctx):
26+
rctx.file("BUILD.bazel", _BUILD_TEMPLATE)
27+
28+
if rctx.attr.path:
29+
path = rctx.attr.path
30+
whl_basename = _basename(path)
31+
dist_name = "dist/{}".format(whl_basename)
32+
rctx.symlink(path, dist_name)
33+
34+
rctx.symlink(path, "torch.zip")
35+
rctx.extract("torch.zip")
36+
rctx.delete("torch.zip")
37+
elif rctx.attr.urls:
38+
urls = rctx.attr.urls
39+
extract_file = "torch.zip"
40+
whl_basename = _urldecode(_basename(urls[0].rpartition("/")[2]))
41+
dist_name = "dist/{}".format(whl_basename)
42+
result = rctx.download(
43+
url = rctx.attr.urls,
44+
output = dist_name,
45+
sha256 = rctx.attr.sha256,
46+
integrity = rctx.attr.sha256,
47+
)
48+
if not result.success:
49+
fail("Failed to download: {}", rctx.attr.urls)
50+
51+
rctx.symlink(dist_name, "torch.zip")
52+
rctx.extract("torch.zip")
53+
rctx.delete("torch.zip")
54+
55+
elif rctx.attr.source_path:
56+
pass
57+
# symlink dist dir
58+
# file structure is different. Need BUILD file
59+
# to look different and torchgen code to be different.
60+
61+
torch_repo = repository_rule(
62+
implementation = _torch_repo_impl,
63+
attrs = {
64+
# Local source checkout with built whl
65+
"source_path": attr.string(),
66+
# Local prebuilt wheel
67+
"path": attr.string(),
68+
# remote whl
69+
"urls": attr.string_list(),
70+
"sha256": attr.string(),
71+
},
72+
)

bazel/torch_whl_targets.bzl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
cc_library = native.cc_library
2+
filegroup = native.filegroup
3+
glob = native.glob
4+
5+
def define_torch_whl_targets():
6+
cc_library(
7+
name = "headers",
8+
hdrs = glob(
9+
["torch/include/**/*.h"],
10+
["torch/include/google/protobuf/**/*.h"],
11+
),
12+
strip_include_prefix = "torch/include",
13+
)
14+
15+
# Runtime headers, for importing <torch/torch.h>.
16+
cc_library(
17+
name = "runtime_headers",
18+
hdrs = glob(["torch/include/torch/csrc/api/include/**/*.h"]),
19+
strip_include_prefix = "torch/include/torch/csrc/api/include",
20+
)
21+
22+
filegroup(
23+
name = "torchgen_deps",
24+
srcs = [
25+
# torchgen/packaged/ instead of aten/src
26+
"torchgen/packaged/ATen/native/native_functions.yaml",
27+
"torchgen/packaged/ATen/native/tags.yaml",
28+
##"torchgen/packaged/ATen/native/ts_native_functions.yaml",
29+
"torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp",
30+
"torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h",
31+
"torchgen/packaged/ATen/templates/LazyIr.h",
32+
"torchgen/packaged/ATen/templates/LazyNonNativeIr.h",
33+
"torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini",
34+
"torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp",
35+
# Add torch/include prefix
36+
"torch/include/torch/csrc/lazy/core/shape_inference.h",
37+
##"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
38+
],
39+
)
40+
41+
# Changed to cc_library from cc_import
42+
43+
cc_library(
44+
name = "libtorch",
45+
srcs = ["torch/lib/libtorch.so"],
46+
)
47+
48+
cc_library(
49+
name = "libtorch_cpu",
50+
srcs = ["torch/lib/libtorch_cpu.so"],
51+
)
52+
53+
cc_library(
54+
name = "libtorch_python",
55+
srcs = [
56+
# Added this
57+
"torch/lib/libshm.so",
58+
"torch/lib/libtorch_python.so",
59+
],
60+
)
61+
62+
cc_library(
63+
name = "libc10",
64+
srcs = ["torch/lib/libc10.so"],
65+
)

codegen/lazy_tensor_generator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717

1818
xla_root = sys.argv[1]
1919
torch_root = os.path.join(xla_root, "torch")
20-
aten_path = os.path.join(torch_root, "aten", "src", "ATen")
21-
shape_inference_hdr = os.path.join(torch_root, "torch", "csrc", "lazy", "core",
22-
"shape_inference.h")
20+
##aten_path = os.path.join(torch_root, "aten", "src", "ATen")
21+
aten_path = os.path.join(torch_root, "torchgen", "packaged", "ATen")
22+
##shape_inference_hdr = os.path.join(torch_root, "torch", "csrc", "lazy", "core",
23+
## "shape_inference.h")
24+
shape_inference_hdr = os.path.join(torch_root, "torch", "include",
25+
"torch", "csrc", "lazy", "core", "shape_inference.h")
2326
impl_path = os.path.join(xla_root, "__main__",
2427
"torch_xla/csrc/aten_xla_type.cpp")
2528
source_yaml = sys.argv[2]

0 commit comments

Comments
 (0)