
Commit 6f423d0

update pin (#8677)
1 parent 42edbe1 commit 6f423d0

7 files changed (+28 −10 lines)

WORKSPACE

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,7 @@ new_local_repository(
 
 # To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
 # the openxla git commit hash.
-xla_hash = '6e91ff19dad528ab7d2025a9bb46150618a3bc7d'
+xla_hash = '52d5ccaf00fdbc32956c457eae415c09f56f0208'
 
 http_archive(
     name = "xla",
@@ -57,6 +57,7 @@ http_archive(
     patch_tool = "patch",
     patches = [
         "//openxla_patches:gpu_race_condition.diff",
+        "//openxla_patches:count_down.diff",
     ],
     strip_prefix = "xla-" + xla_hash,
     urls = [

openxla_patches/count_down.diff

Lines changed: 14 additions & 0 deletions
diff --git a/xla/backends/cpu/runtime/convolution_thunk_internal.h b/xla/backends/cpu/runtime/convolution_thunk_internal.h
index 84fed6bb78..9835f12e4e 100644
--- a/xla/backends/cpu/runtime/convolution_thunk_internal.h
+++ b/xla/backends/cpu/runtime/convolution_thunk_internal.h
@@ -342,7 +342,8 @@ void EigenGenericConv2D(
   Eigen::Index start = task_index * task_size;
   Eigen::Index end = std::min(start + task_size, feature_group_count);
   for (Eigen::Index i = start; i < end; ++i) {
-    auto on_done = [count_down]() mutable { count_down.CountDown(); };
+    // auto on_done = [count_down]() mutable { count_down.CountDown(); };
+    auto on_done = [count_down]() mutable { const_cast<decltype(count_down)*>(&count_down)->CountDown(); };
     auto [output, convolved] = convolve_group(i);
     output.device(device, std::move(on_done)) = convolved;
   }
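
Note on the patch: it reroutes the CountDown() call through a const_cast on the captured copy, which points at a const-qualification error at this pin; the exact trigger depends on how Eigen invokes the done-callback. The underlying C++ rule is that a by-value capture is accessed through a const path inside a non-mutable lambda, so a non-const member function cannot be called on it directly. A minimal sketch of that rule and the cast-based workaround, using a hypothetical Counter in place of XLA's real count-down type:

// Hypothetical Counter standing in for XLA's count-down object; the real
// type lives in tsl and is more involved.
struct Counter {
  int remaining = 1;
  void CountDown() { --remaining; }  // non-const member function
};

int main() {
  Counter counter;

  // Inside a non-mutable lambda, the by-value capture is accessed through a
  // const path, so the direct call does not compile:
  //   auto bad = [counter] { counter.CountDown(); };  // error
  //
  // Casting the const away, as the patch does, makes the call legal here;
  // it is well-defined only because the captured copy is not itself a
  // const object.
  auto on_done = [counter] {
    const_cast<decltype(counter)*>(&counter)->CountDown();
  };
  on_done();
  return 0;
}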

setup.py

Lines changed: 2 additions & 5 deletions
@@ -66,10 +66,10 @@
 
 USE_NIGHTLY = True  # whether to use nightly or stable libtpu and jax
 
-_date = '20250131'
+_date = '20250210'
 
 # Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
-_libtpu_version = '0.0.9'
+_libtpu_version = '0.0.10'
 _jax_version = '0.5.1'
 _jaxlib_version = '0.5.1'
 
@@ -332,9 +332,6 @@ def run(self):
         'tpu': [
             f'libtpu=={_libtpu_version}',
             'tpu-info',
-            # This special version removes `libtpu.so` from any `libtpu-nightly` installations,
-            # since we have migrated to using the `libtpu.so` from the `libtpu` package.
-            "libtpu-nightly==0.1.dev20241010+nightly.cleanup"
         ],
         # pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
         'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],

test/test_core_aten_ops.py

Lines changed: 4 additions & 1 deletion
@@ -993,7 +993,10 @@ def test_aten_convolution_1(self):
         1,
     )
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs)
+    # With the xla pin at 52d5ccaf00fdbc32956c457eae415c09f56f0208,
+    # the rtol needs to be raised to 1e-3 on CPU.
+    run_export_and_compare(
+        self, torch.ops.aten.convolution, args, kwargs, rtol=1e-3)
 
   def test_aten_convolution_2(self):
     args = (
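
Context for the tolerance change: with the new pin, CPU convolution results drift slightly more from the reference, so the comparison is loosened from the default relative tolerance to 1e-3. Assuming an allclose-style check (torch.allclose defaults to rtol=1e-5, atol=1e-8), a value passes when |actual − expected| <= atol + rtol * |expected|. A standalone illustration of what the looser rtol admits:

#include <cmath>
#include <cstdio>

// allclose-style check; defaults mirror torch.allclose (rtol=1e-5, atol=1e-8).
bool AllClose(double actual, double expected, double rtol = 1e-5,
              double atol = 1e-8) {
  return std::fabs(actual - expected) <= atol + rtol * std::fabs(expected);
}

int main() {
  double expected = 1.0;
  double actual = 1.0004;  // ~4e-4 relative error, standing in for the
                           // extra CPU convolution drift at this pin
  std::printf("rtol=1e-5: %s\n", AllClose(actual, expected) ? "pass" : "fail");
  std::printf("rtol=1e-3: %s\n",
              AllClose(actual, expected, /*rtol=*/1e-3) ? "pass" : "fail");
  return 0;
}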

torch_xla/csrc/dl_convertor.cpp

Lines changed: 1 addition & 1 deletion
@@ -329,7 +329,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
       device->client()->CreateViewOfDeviceBuffer(
           static_cast<char*>(dlmt->dl_tensor.data) +
               dlmt->dl_tensor.byte_offset,
-          shape, device, on_delete_callback);
+          shape, *device->default_memory_space(), on_delete_callback);
   XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer.";
   XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";
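
This is the PJRT API migration picked up by the pin: CreateViewOfDeviceBuffer now takes a memory space rather than a device. default_memory_space() returns an absl::StatusOr<xla::PjRtMemorySpace*>, so the leading `*` unwraps the StatusOr, which is valid only when the lookup succeeds. A rough sketch of that lookup-and-unwrap shape, with hypothetical stand-ins for the PJRT types (requires Abseil):

#include <cassert>

#include "absl/status/statusor.h"

// Hypothetical stand-ins for xla::PjRtMemorySpace / xla::PjRtDevice; the
// real classes live in PJRT.
struct MemorySpace {};

struct Device {
  MemorySpace space;
  // Real signature: absl::StatusOr<PjRtMemorySpace*> default_memory_space().
  absl::StatusOr<MemorySpace*> default_memory_space() { return &space; }
};

int main() {
  Device device;
  // `*device.default_memory_space()` in one step; spelled out here:
  absl::StatusOr<MemorySpace*> space_or = device.default_memory_space();
  assert(space_or.ok());  // the diff dereferences without checking
  MemorySpace* space = *space_or;  // pass where APIs previously took a device
  return space != nullptr ? 0 : 1;
}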

torch_xla/csrc/runtime/BUILD

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@ cc_library(
         ":tf_logging",
         "@tsl//tsl/platform:stacktrace",
         "@tsl//tsl/platform:statusor",
+        "@tsl//tsl/platform:macros",
     ],
 )

torch_xla/csrc/runtime/pjrt_computation_client.cc

Lines changed: 4 additions & 2 deletions
@@ -274,7 +274,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToDevice(
             tensor->dimensions(), tensor->byte_strides(),
             xla::PjRtClient::HostBufferSemantics::
                 kImmutableUntilTransferCompletes,
-            [tensor]() { /* frees tensor */ }, pjrt_device)
+            [tensor]() { /* frees tensor */ },
+            *pjrt_device->default_memory_space(),
+            /*device_layout=*/nullptr)
             .value());
 
     ComputationClient::DataPtr data =
@@ -321,7 +323,7 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
 
   // Returns error if the buffer is already on `dst_device`.
   absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> status_or =
-      pjrt_data->buffer->CopyToDevice(dst_device);
+      pjrt_data->buffer->CopyToMemorySpace(*dst_device->default_memory_space());
  if (!status_or.ok()) {
    return data;
  }
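
Both hunks follow the same migration as dl_convertor.cpp: buffers are created in and copied between memory spaces instead of being bound to devices, and BufferFromHostBuffer additionally takes an explicit device layout (nullptr here keeps the default). The CopyToDevice hunk leans on absl::StatusOr for the "already there" case; a minimal sketch of that fallback pattern, with a hypothetical Buffer in place of xla::PjRtBuffer (requires Abseil):

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for xla::PjRtBuffer.
struct Buffer {};

// Mirrors the behavior noted in the comment above: copying returns an error
// when the buffer is already in the destination memory space.
absl::StatusOr<std::unique_ptr<Buffer>> CopyToMemorySpace(bool already_there) {
  if (already_there) {
    return absl::InvalidArgumentError("buffer already in destination space");
  }
  return std::make_unique<Buffer>();
}

int main() {
  auto status_or = CopyToMemorySpace(/*already_there=*/true);
  if (!status_or.ok()) {
    return 0;  // fall back to the existing buffer, as CopyToDevice does
  }
  std::unique_ptr<Buffer> copied = std::move(status_or).value();
  return copied != nullptr ? 0 : 1;
}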
