From 00652e6e6a0ffb9c60a21ac91680bc1ac5528bc4 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 2 May 2024 11:25:18 -0700 Subject: [PATCH] [xla:cpu] NFC: Remove deprecated XLA:CPU mlir based codegen part #3 PiperOrigin-RevId: 630125494 --- tensorflow/compiler/aot/tfcompile.bzl | 2 - third_party/xla/xla/service/cpu/BUILD | 7 - third_party/xla/xla/service/cpu/runtime/BUILD | 201 ---------- .../xla/service/cpu/runtime/collectives.cc | 368 ------------------ .../xla/xla/service/cpu/runtime/collectives.h | 29 -- .../xla/service/cpu/runtime/convolution.cc | 219 ----------- .../xla/xla/service/cpu/runtime/convolution.h | 47 --- .../service/cpu/runtime/convolution_call.cc | 83 ---- .../service/cpu/runtime/convolution_call.h | 28 -- .../service/cpu/runtime/convolution_ffi.cc | 106 ----- .../xla/service/cpu/runtime/convolution_ffi.h | 22 -- .../xla/service/cpu/runtime/custom_call.cc | 177 --------- .../xla/xla/service/cpu/runtime/custom_call.h | 29 -- .../xla/xla/service/cpu/runtime/fft_call.cc | 114 ------ .../xla/xla/service/cpu/runtime/fft_call.h | 28 -- .../xla/xla/service/cpu/runtime/retain.cc | 38 -- .../xla/xla/service/cpu/runtime/rng.cc | 201 ---------- third_party/xla/xla/service/cpu/runtime/rng.h | 46 --- .../xla/xla/service/cpu/runtime/rng_call.cc | 76 ---- .../xla/xla/service/cpu/runtime/rng_call.h | 29 -- .../xla/xla/service/cpu/runtime/rng_ffi.cc | 107 ----- .../xla/xla/service/cpu/runtime/rng_ffi.h | 24 -- .../xla/xla/service/cpu/runtime/xfeed.cc | 189 --------- .../xla/xla/service/cpu/runtime/xfeed.h | 29 -- 24 files changed, 2199 deletions(-) delete mode 100644 third_party/xla/xla/service/cpu/runtime/BUILD delete mode 100644 third_party/xla/xla/service/cpu/runtime/collectives.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/collectives.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/convolution.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/convolution.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/convolution_call.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/convolution_call.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/convolution_ffi.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/convolution_ffi.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/custom_call.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/custom_call.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/fft_call.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/fft_call.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/retain.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/rng.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/rng.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/rng_call.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/rng_call.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/rng_ffi.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/rng_ffi.h delete mode 100644 third_party/xla/xla/service/cpu/runtime/xfeed.cc delete mode 100644 third_party/xla/xla/service/cpu/runtime/xfeed.h diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index a543aae5b92997..99c8541c55488c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -319,8 +319,6 @@ def _tf_library( ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually # needed. - "@local_xla//xla/service/cpu/runtime:convolution_ffi", - "@local_xla//xla/service/cpu/runtime:rng_ffi", "@local_xla//xla/service/cpu:runtime_conv2d", "@local_xla//xla/service/cpu:runtime_custom_call_status", "@local_xla//xla/service/cpu:runtime_key_value_sort", diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index e7073aad05263a..c618b9d3077bef 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -339,13 +339,6 @@ cc_library( "//xla/service:while_loop_invariant_code_motion", "//xla/service:while_loop_simplifier", "//xla/service:zero_sized_hlo_elimination", - "//xla/service/cpu/runtime:collectives", - "//xla/service/cpu/runtime:convolution_call", - "//xla/service/cpu/runtime:custom_call", - "//xla/service/cpu/runtime:fft_call", - "//xla/service/cpu/runtime:retain", - "//xla/service/cpu/runtime:rng_call", - "//xla/service/cpu/runtime:xfeed", "//xla/service/llvm_ir:llvm_command_line_options", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd:stateful_rng_spmd_partitioner", diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD deleted file mode 100644 index 1169fdaa39d394..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ /dev/null @@ -1,201 +0,0 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], - licenses = ["notice"], -) - -package_group( - name = "friends", - includes = [ - "//xla:friends", - ], -) - -cc_library( - name = "retain", - srcs = ["retain.cc"], - visibility = ["//visibility:public"], - alwayslink = 1, -) - -cc_library( - name = "collectives", - srcs = ["collectives.cc"], - hdrs = ["collectives.h"], - deps = [ - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "//xla/service/cpu:cpu_runtime", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "convolution", - srcs = ["convolution.cc"], - hdrs = ["convolution.h"], - deps = [ - "//xla:executable_run_options", - "//xla:xla_data_proto_cc", - "//xla/runtime:memref_view", - "//xla/service/cpu:runtime_conv2d", - "//xla/service/cpu:runtime_conv3d", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", - ], -) - -cc_library( - name = "convolution_ffi", - srcs = ["convolution_ffi.cc"], - hdrs = ["convolution_ffi.h"], - visibility = ["//visibility:public"], - deps = [ - ":convolution", - "//xla:xla_data_proto_cc", - "//xla/runtime:aot_ffi", - "//xla/runtime:aot_ffi_execution_context", - "//xla/runtime:memref_view", - "//xla/runtime/ffi:ffi_api", - "//xla/runtime/ffi:ffi_c_api_hdrs", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "convolution_call", - srcs = ["convolution_call.cc"], - hdrs = ["convolution_call.h"], - deps = [ - ":convolution", - "//xla:executable_run_options", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "@com_google_absl//absl/types:span", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "custom_call", - srcs = ["custom_call.cc"], - hdrs = ["custom_call.h"], - deps = [ - "//xla:shape_util", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "//xla/service:custom_call_status_internal", - "//xla/service:custom_call_status_public_headers", - "//xla/service:custom_call_target_registry", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "fft_call", - srcs = ["fft_call.cc"], - hdrs = ["fft_call.h"], - deps = [ - "//xla:executable_run_options", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "//xla/service:hlo_proto_cc", - "//xla/service/cpu:runtime_fft", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "xfeed", - srcs = ["xfeed.cc"], - hdrs = ["xfeed.h"], - deps = [ - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "//xla/service/cpu:cpu_runtime", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "rng", - srcs = ["rng.cc"], - hdrs = ["rng.h"], - deps = [ - "//xla:executable_run_options", - "//xla:xla_data_proto_cc", - "//xla/runtime:memref_view", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "rng_call", - srcs = ["rng_call.cc"], - hdrs = ["rng_call.h"], - deps = [ - ":rng", - "//xla:executable_run_options", - "//xla/runtime:custom_call", - "//xla/runtime:custom_call_registry", - "//xla/runtime:executable", - "//xla/runtime:memref_view", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "rng_ffi", - srcs = ["rng_ffi.cc"], - hdrs = ["rng_ffi.h"], - visibility = ["//visibility:public"], - deps = [ - ":rng", - "//xla:xla_data_proto_cc", - "//xla/runtime:aot_ffi", - "//xla/runtime:aot_ffi_execution_context", - "//xla/runtime:memref_view", - "//xla/runtime/ffi:ffi_api", - "//xla/runtime/ffi:ffi_c_api_hdrs", - "@com_google_absl//absl/status", - ], -) diff --git a/third_party/xla/xla/service/cpu/runtime/collectives.cc b/third_party/xla/xla/service/cpu/runtime/collectives.cc deleted file mode 100644 index 6034cc600245b2..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/collectives.cc +++ /dev/null @@ -1,368 +0,0 @@ -// Copyright 2022 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "xla/service/cpu/runtime/collectives.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "xla/executable_run_options.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/cpu/cpu_runtime.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace cpu { - -using mlir::succeeded; - -using ::xla::runtime::CustomCall; -using ::xla::runtime::Executable; -using ::xla::runtime::MemrefView; - -// Disable all CustomCall checks in optimized build. -static constexpr CustomCall::RuntimeChecks RuntimeChecks() { -#if defined(NDEBUG) - return CustomCall::RuntimeChecks::kNone; -#else - return CustomCall::RuntimeChecks::kDefault; -#endif -} - -static std::string ReplicaGroupsToString( - CustomCall::TensorRef replica_groups) { - if (replica_groups.shape[0] == 0) { - return "{}"; - } - std::string result; - - const auto& shape = replica_groups.shape; - size_t stride = replica_groups.data.size() / shape[0]; - - absl::StrAppend(&result, "{"); - for (size_t i = 0; i < replica_groups.data.size(); i += stride) { - if (i > 0) { - absl::StrAppend(&result, ", "); - } - - auto start = replica_groups.data.begin() + i; - llvm::ArrayRef inner_data(start, start + stride); - - absl::StrAppend(&result, "{"); - absl::StrAppend( - &result, - // The replica groups can have different sizes. Smaller groups are - // padded with -1. - absl::StrJoin(llvm::make_filter_range( - inner_data, [](int64_t id) { return id >= 0; }), - ", ")); - absl::StrAppend(&result, "}"); - } - absl::StrAppend(&result, "}"); - - return result; -} - -static std::string SourceTargetPairsToString( - CustomCall::TensorRef source_target_pairs) { - std::string result; - for (size_t i = 0; i < source_target_pairs.data.size(); i += 2) { - if (i > 0) { - absl::StrAppend(&result, ","); - } - absl::StrAppend(&result, source_target_pairs.data[i], "=", - source_target_pairs.data[i + 1]); - } - return result; -} - -// -------------------------------------------------------------------------- // - -namespace { -struct XlaPartitionId { - absl::StatusOr operator()( - const ExecutableRunOptions* run_options) const; - static XlaPartitionId Handler() { return XlaPartitionId(); } -}; -} // namespace - -absl::StatusOr XlaPartitionId::operator()( - const ExecutableRunOptions* run_options) const { - int32_t result; - __xla_cpu_runtime_PartitionId(run_options, &result); - return result; -} - -static bool PartitionId(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = CustomCall::Bind("xla.cpu.partition_id") - .Ret() - .UserData() - .To(XlaPartitionId::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -// -------------------------------------------------------------------------- // - -namespace { -struct XlaReplicaId { - absl::StatusOr operator()( - const ExecutableRunOptions* run_options) const; - static XlaReplicaId Handler() { return XlaReplicaId(); } -}; -} // namespace - -absl::StatusOr XlaReplicaId::operator()( - const ExecutableRunOptions* run_options) const { - int32_t result; - __xla_cpu_runtime_ReplicaId(run_options, &result); - return result; -} - -static bool ReplicaId(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = CustomCall::Bind("xla.cpu.replica_id") - .Ret() - .UserData() - .To(XlaReplicaId::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -// -------------------------------------------------------------------------- // - -namespace { -struct XlaAllReduce { - absl::Status operator()(const ExecutableRunOptions* run_options, - CustomCall::RemainingArgs buffers, - CustomCall::TensorRef replica_groups, - int64_t channel_id, int32_t use_global_device_ids, - int64_t op_id, int32_t reduction_kind) const; - static XlaAllReduce Handler() { return XlaAllReduce(); } -}; -} // namespace - -absl::Status XlaAllReduce::operator()( - const ExecutableRunOptions* run_options, CustomCall::RemainingArgs buffers, - CustomCall::TensorRef replica_groups, int64_t channel_id, - int32_t use_global_device_ids, int64_t op_id, - int32_t reduction_kind) const { - if (replica_groups.shape.size() != 2) { - return absl::InvalidArgumentError("replica_groups must be a 2d tensor."); - } - - if (buffers.size() % 2) { - return absl::InvalidArgumentError( - "number of input buffers and output buffers must be equal."); - } - - std::string replica_groups_str = ReplicaGroupsToString(replica_groups); - int64_t num_buffers = static_cast(buffers.size()) / 2; - - llvm::SmallVector input_buffers, output_buffers; - ShapeProto shape; - for (int i = 0; i < num_buffers; ++i) { - auto input = buffers.get(i); - auto output = buffers.get(i + num_buffers); - if (!succeeded(input) || !succeeded(output)) { - return absl::InvalidArgumentError("all arguments must be memrefs."); - } - - *shape.add_tuple_shapes() = - ShapeUtil::MakeShapeWithDescendingLayout(input->dtype, input->sizes) - .ToProto(); - input_buffers.push_back(input->data); - output_buffers.push_back(output->data); - } - std::string shape_str = - (shape.tuple_shapes().size() == 1 ? shape.tuple_shapes(0) : shape) - .SerializeAsString(); - - __xla_cpu_runtime_AllReduce( - run_options, replica_groups_str.c_str(), - static_cast(replica_groups_str.size()), - static_cast(channel_id), use_global_device_ids, op_id, - reduction_kind, shape_str.c_str(), static_cast(shape_str.size()), - static_cast(num_buffers), input_buffers.data(), - output_buffers.data()); - - return absl::OkStatus(); -} - -static bool AllReduce(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = - CustomCall::Bind("xla.cpu.all_reduce") - .UserData() - .RemainingArgs() - .Attr>("replica_groups") - .Attr("channel_handle") - .Attr("use_global_device_ids") - .Attr("op_id") - .Attr("reduction_kind") - .To(XlaAllReduce::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -// -------------------------------------------------------------------------- // - -namespace { -struct XlaTupleAllToAll { - absl::Status operator()(const ExecutableRunOptions* run_options, - CustomCall::RemainingArgs buffers, - CustomCall::TensorRef replica_groups, - int32_t channel_id_present, int64_t op_id) const; - static XlaTupleAllToAll Handler() { return XlaTupleAllToAll(); } -}; -} // namespace - -absl::Status XlaTupleAllToAll::operator()( - const ExecutableRunOptions* run_options, CustomCall::RemainingArgs buffers, - CustomCall::TensorRef replica_groups, int32_t channel_id_present, - int64_t op_id) const { - if (replica_groups.shape.size() != 2) { - return absl::InvalidArgumentError("replica_groups must be a 2d tensor."); - } - - if (buffers.size() % 2) { - return absl::InvalidArgumentError( - "number of input buffers and output buffers must be equal."); - } - - std::string replica_groups_str = ReplicaGroupsToString(replica_groups); - int64_t num_buffers = static_cast(buffers.size()) / 2; - - llvm::SmallVector input_buffers, output_buffers; - for (int i = 0; i < num_buffers; ++i) { - auto input = buffers.get(i); - auto output = buffers.get(i + num_buffers); - if (!succeeded(input) || !succeeded(output)) { - return absl::InvalidArgumentError("all arguments must be memrefs."); - } - - input_buffers.push_back(input->data); - output_buffers.push_back(output->data); - } - - auto first_input = *buffers.get(0); - size_t buffer_size = ShapeUtil::ByteSizeOfElements( - ShapeUtil::MakeShape(first_input.dtype, first_input.sizes)); - - __xla_cpu_runtime_AllToAll( - run_options, channel_id_present, op_id, replica_groups_str.c_str(), - static_cast(replica_groups_str.size()), - static_cast(num_buffers), static_cast(buffer_size), - input_buffers.data(), output_buffers.data()); - - return absl::OkStatus(); -} - -static bool TupleAllToAll(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = - CustomCall::Bind("xla.cpu.all_reduce") - .UserData() - .RemainingArgs() - .Attr>("replica_groups") - .Attr("channel_id_present") - .Attr("op_id") - .To(XlaTupleAllToAll::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -// -------------------------------------------------------------------------- // - -namespace { -struct XlaCollectivePermute { - absl::Status operator()(const ExecutableRunOptions* run_options, - MemrefView input, MemrefView output, - CustomCall::TensorRef source_target_pairs, - int64_t channel_id) const; - static XlaCollectivePermute Handler() { return XlaCollectivePermute(); } -}; -} // namespace - -absl::Status XlaCollectivePermute::operator()( - const ExecutableRunOptions* run_options, MemrefView input, - MemrefView output, CustomCall::TensorRef source_target_pairs, - int64_t channel_id) const { - if (source_target_pairs.shape.size() != 2 || - source_target_pairs.shape[1] != 2) { - return absl::InvalidArgumentError( - "source_target_pairs must be a ?x2 tensor."); - } - size_t byte_size = ShapeUtil::ByteSizeOfElements( - ShapeUtil::MakeShape(input.dtype, input.sizes)); - std::string source_target_pairs_str = - SourceTargetPairsToString(source_target_pairs); - - __xla_cpu_runtime_CollectivePermute( - run_options, static_cast(channel_id), 0, - static_cast(byte_size), input.data, output.data, - source_target_pairs_str.c_str(), - static_cast(source_target_pairs_str.size())); - - return absl::OkStatus(); -} - -static bool CollectivePermute(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = - CustomCall::Bind("xla.cpu.collective_permute") - .UserData() - .Arg() // input - .Arg() // output - .Attr>("source_target_pairs") - .Attr("channel_handle") - .To(XlaCollectivePermute::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -void PopulateXlaCpuCollectivesCall( - xla::runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.cpu.all_reduce", &xla::cpu::AllReduce); - registry.Register("xla.cpu.tuple_all_to_all", &xla::cpu::TupleAllToAll); - registry.Register("xla.cpu.collective_permute", &xla::cpu::CollectivePermute); - registry.Register("xla.cpu.partition_id", &xla::cpu::PartitionId); - registry.Register("xla.cpu.replica_id", &xla::cpu::ReplicaId); -} - -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime/collectives.h b/third_party/xla/xla/service/cpu/runtime/collectives.h deleted file mode 100644 index 043a3aaeb12223..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/collectives.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2022 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef XLA_SERVICE_CPU_RUNTIME_COLLECTIVES_H_ -#define XLA_SERVICE_CPU_RUNTIME_COLLECTIVES_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace cpu { - -// Populate custom call implementing XLA CPU collectives -void PopulateXlaCpuCollectivesCall(runtime::DirectCustomCallRegistry& registry); - -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_RUNTIME_COLLECTIVES_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/convolution.cc b/third_party/xla/xla/service/cpu/runtime/convolution.cc deleted file mode 100644 index bc2c7ef29b2535..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/convolution.cc +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "xla/service/cpu/runtime/convolution.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "Eigen/Core" // from @eigen_archive -#include "xla/executable_run_options.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/cpu/runtime_conv2d.h" -#include "xla/service/cpu/runtime_conv3d.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace cpu { - -using ::xla::runtime::MemrefView; - -absl::Status XlaConvolution::operator()( - const ExecutableRunOptions* run_options, MemrefView input, - MemrefView kernel, MemrefView output, int64_t inputBatchDimension, - absl::Span inputSpatialDimensions, - int64_t inputFeatureDimension, - absl::Span kernelSpatialDimensions, - int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension, - absl::Span outputSpatialDimensions, - absl::Span window_strides, absl::Span padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, int64_t feature_group_count) const { - auto size = inputSpatialDimensions.size(); - if (size < 1 || size > 3) { - return absl::InvalidArgumentError( - "Only 1D, 2D and 3D convolutions are supported"); - } - - if (size != kernelSpatialDimensions.size() || - size != outputSpatialDimensions.size() || size != window_strides.size() || - size * 2 != padding.size() || size != lhs_dilation.size() || - size != rhs_dilation.size()) { - return absl::InvalidArgumentError("Number of attributes mismatched"); - } - - // We lower 1D convolutions into calls to the same Eigen function as 2D - // convolutions, except that we pretend that the 1D convolution is really a 2D - // convolution with the missing dimension set to 1. We also adjust the - // padding, dilation parameters as needed. - std::vector input_dims; - std::vector kernel_dims; - std::vector output_dims; - std::vector strides; - std::vector pad; - std::vector base_dilation; - std::vector window_dilation; - if (size == 1) { - input_dims.push_back(1); - kernel_dims.push_back(1); - output_dims.push_back(1); - strides.push_back(1); - pad.insert(pad.end(), {0, 0}); - base_dilation.push_back(1); - window_dilation.push_back(1); - } - for (auto dim : inputSpatialDimensions) { - input_dims.push_back(input.sizes[dim]); - } - for (auto dim : kernelSpatialDimensions) { - kernel_dims.push_back(kernel.sizes[dim]); - } - for (auto dim : outputSpatialDimensions) { - output_dims.push_back(output.sizes[dim]); - } - strides.insert(strides.end(), window_strides.begin(), window_strides.end()); - pad.insert(pad.end(), padding.begin(), padding.end()); - base_dilation.insert(base_dilation.end(), lhs_dilation.begin(), - lhs_dilation.end()); - window_dilation.insert(window_dilation.end(), rhs_dilation.begin(), - rhs_dilation.end()); - - if (output.dtype == PrimitiveType::F16) { - auto* out = reinterpret_cast(output.data); - auto* lhs = reinterpret_cast(input.data); - auto* rhs = reinterpret_cast(kernel.data); - if (size != 3) { - __xla_cpu_runtime_EigenConv2DF16( - run_options, out, lhs, rhs, - /*input_batch*/ input.sizes[inputBatchDimension], - /*input_rows*/ input_dims[0], - /*input_cols*/ input_dims[1], - /*input_channels*/ input.sizes[inputFeatureDimension], - /*kernel_rows*/ kernel_dims[0], - /*kernel_cols*/ kernel_dims[1], - /*kernel_channels*/ kernel.sizes[kernelInputFeatureDimension], - /*kernel_filters*/ kernel.sizes[kernelOutputFeatureDimension], - /*output_rows*/ output_dims[0], - /*output_cols*/ output_dims[1], - /*row_stride*/ strides[0], - /*col_stride*/ strides[1], - /*padding_top*/ pad[0], - /*padding_bottom*/ pad[1], - /*padding_left*/ pad[2], - /*padding_right*/ pad[3], - /*lhs_row_dilation*/ base_dilation[0], - /*lhs_col_dilation*/ base_dilation[1], - /*rhs_row_dilation*/ window_dilation[0], - /*rhs_col_dilation*/ window_dilation[1], feature_group_count); - } else { - __xla_cpu_runtime_EigenConv3DF16( - run_options, out, lhs, rhs, - /*input_batch*/ input.sizes[inputBatchDimension], - /*input_x*/ input_dims[0], - /*input_y*/ input_dims[1], - /*input_z*/ input_dims[2], - /*input_channels*/ input.sizes[inputFeatureDimension], - /*kernel_x*/ kernel_dims[0], - /*kernel_y*/ kernel_dims[1], - /*kernel_z*/ kernel_dims[2], - /*kernel_channels*/ kernel.sizes[kernelInputFeatureDimension], - /*kernel_filters*/ kernel.sizes[kernelOutputFeatureDimension], - /*output_x*/ output_dims[0], - /*output_y*/ output_dims[1], - /*output_z*/ output_dims[2], - /*x_stride*/ strides[0], - /*y_stride*/ strides[1], - /*z_stride*/ strides[2], - /*padding_x_before*/ pad[0], - /*padding_x_after*/ pad[1], - /*padding_y_before*/ pad[2], - /*padding_y_after*/ pad[3], - /*padding_z_before*/ pad[4], - /*padding_z_after*/ pad[5], - /*lhs_x_dilation*/ base_dilation[0], - /*lhs_y_dilation*/ base_dilation[1], - /*lhs_z_dilation*/ base_dilation[2], - /*rhs_x_dilation*/ window_dilation[0], - /*rhs_y_dilation*/ window_dilation[1], - /*rhs_z_dilation*/ window_dilation[2], feature_group_count); - } - } else { - auto* out = reinterpret_cast(output.data); - auto* lhs = reinterpret_cast(input.data); - auto* rhs = reinterpret_cast(kernel.data); - if (size != 3) { - __xla_cpu_runtime_EigenConv2DF32( - run_options, out, lhs, rhs, - /*input_batch*/ input.sizes[inputBatchDimension], - /*input_rows*/ input_dims[0], - /*input_cols*/ input_dims[1], - /*input_channels*/ input.sizes[inputFeatureDimension], - /*kernel_rows*/ kernel_dims[0], - /*kernel_cols*/ kernel_dims[1], - /*kernel_channels*/ kernel.sizes[kernelInputFeatureDimension], - /*kernel_filters*/ kernel.sizes[kernelOutputFeatureDimension], - /*output_rows*/ output_dims[0], - /*output_cols*/ output_dims[1], - /*row_stride*/ strides[0], - /*col_stride*/ strides[1], - /*padding_top*/ pad[0], - /*padding_bottom*/ pad[1], - /*padding_left*/ pad[2], - /*padding_right*/ pad[3], - /*lhs_row_dilation*/ base_dilation[0], - /*lhs_col_dilation*/ base_dilation[1], - /*rhs_row_dilation*/ window_dilation[0], - /*rhs_col_dilation*/ window_dilation[1], feature_group_count); - } else { - __xla_cpu_runtime_EigenConv3DF32( - run_options, out, lhs, rhs, - /*input_batch*/ input.sizes[inputBatchDimension], - /*input_x*/ input_dims[0], - /*input_y*/ input_dims[1], - /*input_z*/ input_dims[2], - /*input_channels*/ input.sizes[inputFeatureDimension], - /*kernel_x*/ kernel_dims[0], - /*kernel_y*/ kernel_dims[1], - /*kernel_z*/ kernel_dims[2], - /*kernel_channels*/ kernel.sizes[kernelInputFeatureDimension], - /*kernel_filters*/ kernel.sizes[kernelOutputFeatureDimension], - /*output_x*/ output_dims[0], - /*output_y*/ output_dims[1], - /*output_z*/ output_dims[2], - /*x_stride*/ strides[0], - /*y_stride*/ strides[1], - /*z_stride*/ strides[2], - /*padding_x_before*/ pad[0], - /*padding_x_after*/ pad[1], - /*padding_y_before*/ pad[2], - /*padding_y_after*/ pad[3], - /*padding_z_before*/ pad[4], - /*padding_z_after*/ pad[5], - /*lhs_x_dilation*/ base_dilation[0], - /*lhs_y_dilation*/ base_dilation[1], - /*lhs_z_dilation*/ base_dilation[2], - /*rhs_x_dilation*/ window_dilation[0], - /*rhs_y_dilation*/ window_dilation[1], - /*rhs_z_dilation*/ window_dilation[2], feature_group_count); - } - } - - return absl::OkStatus(); -} - -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime/convolution.h b/third_party/xla/xla/service/cpu/runtime/convolution.h deleted file mode 100644 index fe4433774a7040..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/convolution.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_H_ - -#include - -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "xla/executable_run_options.h" -#include "xla/runtime/memref_view.h" - -namespace xla { -namespace cpu { - -struct XlaConvolution { - absl::Status operator()( - const ExecutableRunOptions* run_options, xla::runtime::MemrefView input, - xla::runtime::MemrefView kernel, xla::runtime::MemrefView output, - int64_t inputBatchDimension, - absl::Span inputSpatialDimensions, - int64_t inputFeatureDimension, - absl::Span kernelSpatialDimensions, - int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension, - absl::Span outputSpatialDimensions, - absl::Span window_strides, - absl::Span padding, absl::Span lhs_dilation, - absl::Span rhs_dilation, - int64_t feature_group_count) const; - static XlaConvolution Handler() { return XlaConvolution(); } -}; - -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_call.cc b/third_party/xla/xla/service/cpu/runtime/convolution_call.cc deleted file mode 100644 index 793f6285da40c1..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/convolution_call.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "xla/service/cpu/runtime/convolution_call.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "xla/executable_run_options.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/cpu/runtime/convolution.h" - -namespace xla { -namespace cpu { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::Executable; -using ::xla::runtime::MemrefView; - -// Disable all CustomCall checks in optimized build. -static constexpr CustomCall::RuntimeChecks RuntimeChecks() { -#if defined(NDEBUG) - return CustomCall::RuntimeChecks::kNone; -#else - return CustomCall::RuntimeChecks::kDefault; -#endif -} - -static bool Convolution(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = - CustomCall::Bind("xla_cpu_convolution") - .UserData() - .Arg() // input - .Arg() // kernel - .Arg() // output - .Attr("inputBatchDimension") - .Attr>("inputSpatialDimensions") - .Attr("inputFeatureDimension") - .Attr>("kernelSpatialDimensions") - .Attr("kernelInputFeatureDimension") - .Attr("kernelOutputFeatureDimension") - .Attr>("outputSpatialDimensions") - .Attr>("window_strides") - .Attr>("padding") - .Attr>("lhs_dilation") - .Attr>("rhs_dilation") - .Attr("feature_group_count") - .To(xla::cpu::XlaConvolution::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -void PopulateXlaCpuConvolutionCall( - xla::runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla_cpu_convolution", &Convolution); -} - -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_call.h b/third_party/xla/xla/service/cpu/runtime/convolution_call.h deleted file mode 100644 index 07bc96c51b4bfc..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/convolution_call.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_CALL_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_CALL_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace cpu { - -// Populate custom call implementing XLA CPU Convolution. -void PopulateXlaCpuConvolutionCall(runtime::DirectCustomCallRegistry& registry); - -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_CALL_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_ffi.cc b/third_party/xla/xla/service/cpu/runtime/convolution_ffi.cc deleted file mode 100644 index 9673938a05eea9..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/convolution_ffi.cc +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "xla/service/cpu/runtime/convolution_ffi.h" - -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "xla/runtime/aot_ffi.h" -#include "xla/runtime/aot_ffi_execution_context.h" -#include "xla/runtime/ffi/ffi_api.h" -#include "xla/runtime/ffi/ffi_c_api.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/cpu/runtime/convolution.h" -#include "xla/xla_data.pb.h" - -namespace xla { -struct ExecutableRunOptions; -} // namespace xla - -namespace aot = ::xla::runtime::aot; -namespace ffi = ::xla::runtime::ffi; - -namespace { - -using ::xla::runtime::MemrefView; - -ffi::FfiStatus ConvolutionFfi( - xla::ExecutableRunOptions* executable_run_options, ffi::BufferArg input, - ffi::BufferArg kernel, ffi::BufferArg output, int64_t inputBatchDimension, - ffi::Span inputSpatialDimensions, - int64_t inputFeatureDimension, - ffi::Span kernelSpatialDimensions, - int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension, - ffi::Span outputSpatialDimensions, - ffi::Span window_strides, ffi::Span padding, - ffi::Span lhs_dilation, - ffi::Span rhs_dilation, int64_t feature_group_count) { - auto to_memref_view = [](const ffi::BufferArg& view) -> MemrefView { - auto dtype = static_cast(view.dtype); - return MemrefView{ - dtype, view.data, - absl::MakeConstSpan(view.sizes.begin(), view.sizes.end())}; - }; - auto to_span = - [](ffi::Span span) -> absl::Span { - return absl::MakeConstSpan(span.begin(), span.end()); - }; - - xla::cpu::XlaConvolution convolution; - absl::Status status = convolution( - executable_run_options, to_memref_view(input), to_memref_view(kernel), - to_memref_view(output), inputBatchDimension, - to_span(inputSpatialDimensions), inputFeatureDimension, - to_span(kernelSpatialDimensions), kernelInputFeatureDimension, - kernelOutputFeatureDimension, to_span(outputSpatialDimensions), - to_span(window_strides), to_span(padding), to_span(lhs_dilation), - to_span(rhs_dilation), feature_group_count); - return status.ok() ? ffi::FfiStatus::Ok() : ffi::FfiStatus::Internal("err"); -} - -XLA_FFI_DEFINE_FUNCTION( - FFI_Convolution, ConvolutionFfi, - ffi::Ffi::Binding() - .ApiPriv() - .Arg() // input - .Arg() // kernel - .Arg() // output - .Attr("inputBatchDimension") - .Attr>("inputSpatialDimensions") - .Attr("inputFeatureDimension") - .Attr>("kernelSpatialDimensions") - .Attr("kernelInputFeatureDimension") - .Attr("kernelOutputFeatureDimension") - .Attr>("outputSpatialDimensions") - .Attr>("window_strides") - .Attr>("padding") - .Attr>("lhs_dilation") - .Attr>("rhs_dilation") - .Attr("feature_group_count")); - -} // namespace - -bool xla_cpu_convolution(void* execution_context, void** args, void** attrs, - void** rets) { - auto ctx = static_cast(execution_context); - void* executable_run_options = ctx->custom_call_data; - - XLA_FFI_Api api = aot::FfiApi(); - api.priv = executable_run_options; - - XLA_FFI_Function_Args ffi_args = aot::FfiArgs(&api, args, attrs, rets); - - XLA_FFI_Error* error = FFI_Convolution(&ffi_args); - return aot::ProcessErrorIfAny(error); -} diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_ffi.h b/third_party/xla/xla/service/cpu/runtime/convolution_ffi.h deleted file mode 100644 index 7ca9319269a547..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/convolution_ffi.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_FFI_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_FFI_H_ - -extern "C" { -bool xla_cpu_convolution(void* execution_context, void** args, void** attrs, - void** rets); -} // extern "C" - -#endif // XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_FFI_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call.cc b/third_party/xla/xla/service/cpu/runtime/custom_call.cc deleted file mode 100644 index 6b45f3d1a36718..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/custom_call.cc +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2022 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "xla/service/cpu/runtime/custom_call.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "xla/primitive_util.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/custom_call_status_internal.h" -#include "xla/service/custom_call_target_registry.h" -#include "xla/service/hlo.pb.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace cpu { - -using mlir::StringRef; -using mlir::succeeded; - -using ::xla::runtime::CustomCall; -using ::xla::runtime::Executable; - -// Disable all CustomCall checks in optimized build. -static constexpr CustomCall::RuntimeChecks RuntimeChecks() { -#if defined(NDEBUG) - return CustomCall::RuntimeChecks::kNone; -#else - return CustomCall::RuntimeChecks::kDefault; -#endif -} - -// -------------------------------------------------------------------------- // - -namespace { -struct XlaCustomCall { - absl::Status operator()(CustomCall::RemainingArgs args, int32_t num_results, - bool output_tuple, StringRef call_target_name, - int32_t api_version, StringRef backend_config) const; - static XlaCustomCall Handler() { return XlaCustomCall(); } -}; -} // namespace - -absl::Status XlaCustomCall::operator()(CustomCall::RemainingArgs args, - int32_t num_results, bool output_tuple, - StringRef call_target_name, - int32_t api_version, - StringRef backend_config) const { - // Find the Xla custom call handler. - void* call_target = CustomCallTargetRegistry::Global()->Lookup( - call_target_name.str(), "Host"); - if (!call_target) { - return absl::InvalidArgumentError(absl::StrCat( - "Cannot find the Xla custom call handler ", call_target_name.str())); - } - - // Prepare pointers to buffers to pass to the Xla custom call handler. - llvm::SmallVector buffers; - for (unsigned i = 0; i < args.size(); ++i) { - // We use zero-sized memrefs to represent holes in custom calls with target - // arguments mapping (see `CustomCallTargetArgMapping`). - if (auto memref = args.get(i); succeeded(memref)) { - buffers.push_back(memref->size_in_bytes == 0 ? nullptr : memref->data); - continue; - } - if (auto strided = args.get(i); - succeeded(strided)) { - int64_t size_in_bytes = primitive_util::ByteWidth(strided->dtype); - for (int64_t size : strided->sizes) size_in_bytes *= size; - buffers.push_back(size_in_bytes == 0 ? nullptr : strided->data); - continue; - } - return absl::InvalidArgumentError( - "Failed to get arguments as (strided) memref view"); - } - - // Multiple result buffers are passed as a tuple, which is represented as a - // buffer of pointers. - void* result_buffer = - !output_tuple ? buffers.back() : buffers.end() - num_results; - - // Original custom call API version that doesn't support returning status. - if (api_version == CustomCallApiVersion::API_VERSION_ORIGINAL) { - using XlaCustomCallType = void (*)(void* /*result*/, void** /*args*/); - auto xla_call_target = reinterpret_cast(call_target); - - xla_call_target(result_buffer, buffers.data()); - - return absl::OkStatus(); - } - - // Xla Custom call API returning status. - if (api_version == CustomCallApiVersion::API_VERSION_STATUS_RETURNING) { - using XlaCustomCallType = void (*)(void* /*result*/, void** /*args*/, - XlaCustomCallStatus* /*status*/); - auto xla_call_target = reinterpret_cast(call_target); - - XlaCustomCallStatus custom_call_status; - xla_call_target(result_buffer, buffers.data(), &custom_call_status); - - if (auto message = CustomCallStatusGetMessage(&custom_call_status)) { - return absl::InternalError(message.value()); - } else { - return absl::OkStatus(); - } - } - - if (api_version == - CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED) { - using XlaCustomCallType = - void (*)(void* /*result*/, void** /*args*/, const char*, size_t, - XlaCustomCallStatus* /*status*/); - auto xla_call_target = reinterpret_cast(call_target); - - XlaCustomCallStatus custom_call_status; - xla_call_target(result_buffer, buffers.data(), backend_config.data(), - backend_config.size(), &custom_call_status); - - if (auto message = CustomCallStatusGetMessage(&custom_call_status)) { - return absl::InternalError(message.value()); - } else { - return absl::OkStatus(); - } - } - - return absl::InvalidArgumentError("Incorrect custom call API version"); -} - -static bool CustomCall(runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = CustomCall::Bind("xla.cpu.custom_call") - .Arg() // args - .Attr("num_results") - .Attr("output_tuple") - .Attr("call_target_name") - .Attr("api_version") - .Attr("backend_config") - .To(XlaCustomCall::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -void PopulateXlaCpuCustomCall(runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.cpu.custom_call", &xla::cpu::CustomCall); -} - -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call.h b/third_party/xla/xla/service/cpu/runtime/custom_call.h deleted file mode 100644 index c4992e60cb9a2d..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/custom_call.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2022 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_H_ -#define XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace cpu { - -// Populate custom call implementing XLA CPU runtime API for the legacy ABI. -void PopulateXlaCpuCustomCall(runtime::DirectCustomCallRegistry& registry); - -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/fft_call.cc b/third_party/xla/xla/service/cpu/runtime/fft_call.cc deleted file mode 100644 index c62b57422a1543..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/fft_call.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2022 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "xla/service/cpu/runtime/fft_call.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "xla/executable_run_options.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/cpu/runtime_fft.h" -#include "xla/service/hlo.pb.h" -#include "xla/xla.pb.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace cpu { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::Executable; -using ::xla::runtime::MemrefView; - -// Disable all CustomCall checks in optimized build. -static constexpr CustomCall::RuntimeChecks RuntimeChecks() { -#if defined(NDEBUG) - return CustomCall::RuntimeChecks::kNone; -#else - return CustomCall::RuntimeChecks::kDefault; -#endif -} - -namespace { -struct XlaFft { - absl::Status operator()(const ExecutableRunOptions* run_options, - MemrefView input, MemrefView output, int32_t fft_type, - absl::Span fft_length) const; - static XlaFft Handler() { return XlaFft(); } -}; -} // namespace - -absl::Status XlaFft::operator()(const ExecutableRunOptions* run_options, - MemrefView input, MemrefView output, - int32_t fft_type, - absl::Span fft_length) const { - bool double_precision = output.dtype == PrimitiveType::C128; - auto fft_rank = static_cast(fft_length.size()); - if (fft_length.empty() || fft_length.size() > input.sizes.length()) { - return absl::InvalidArgumentError(absl::StrCat( - "fft_length must contain between 1 and ", input.sizes.length(), - " elements for an input with rank ", input.sizes.length())); - } - - // Flatten batch dimensions. - absl::InlinedVector input_sizes(fft_rank + 1); - int64_t input_batch = 1; - int64_t dim_offset = input.sizes.size() - fft_rank; - for (int64_t dim = 0; dim < dim_offset; ++dim) { - input_batch *= input.sizes[dim]; - } - input_sizes[0] = input_batch; - for (int64_t dim = 0; dim < fft_rank; ++dim) { - input_sizes[1 + dim] = input.sizes[dim_offset + dim]; - } - __xla_cpu_runtime_DuccFft(run_options, output.data, input.data, fft_type, - static_cast(double_precision), fft_rank, - input_sizes.data(), fft_length.data()); - return absl::OkStatus(); -} - -static bool Fft(xla::runtime::ExecutionContext* ctx, void** args, void** attrs, - void** rets) { - static auto* handler = CustomCall::Bind("xla.cpu.fft") - .UserData() - .Arg() // input - .Arg() // output - .Attr("fft_type") - .Attr>("fft_length") - .To(XlaFft::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -void PopulateXlaCpuFftCall(xla::runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.cpu.fft", &Fft); -} - -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime/fft_call.h b/third_party/xla/xla/service/cpu/runtime/fft_call.h deleted file mode 100644 index 7e728824fefd10..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/fft_call.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2022 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef XLA_SERVICE_CPU_RUNTIME_FFT_CALL_H_ -#define XLA_SERVICE_CPU_RUNTIME_FFT_CALL_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace cpu { - -// Populate custom call implementing XLA CPU FFT. -void PopulateXlaCpuFftCall(runtime::DirectCustomCallRegistry& registry); - -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_RUNTIME_FFT_CALL_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/retain.cc b/third_party/xla/xla/service/cpu/runtime/retain.cc deleted file mode 100644 index 431c0f75a8c8c4..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/retain.cc +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -extern "C" void retainBuffers(int64_t numAllocs, void** allocBuffers, - int64_t numRetained, void** retainedBuffers) { - for (int64_t i = 0; i < numRetained; ++i) { - void* retained = retainedBuffers[i]; - retainedBuffers[i] = nullptr; - for (int64_t j = 0; j < numAllocs; ++j) { - if (allocBuffers[j] == retained) { - std::swap(allocBuffers[j], retainedBuffers[i]); - break; - } - } - } - - for (int64_t i = 0; i < numAllocs; ++i) { - if (allocBuffers[i]) { - free(allocBuffers[i]); - } - } -} diff --git a/third_party/xla/xla/service/cpu/runtime/rng.cc b/third_party/xla/xla/service/cpu/runtime/rng.cc deleted file mode 100644 index 7f2edd42b56b26..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/rng.cc +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "xla/service/cpu/runtime/rng.h" - -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "xla/executable_run_options.h" -#include "xla/runtime/memref_view.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace cpu { - -using ::xla::runtime::FlatMemrefView; - -static std::array threefry2x32(std::array key, - std::array ctr) { - constexpr std::array, 2> rotations{ - std::array{13, 15, 26, 6}, std::array{17, 29, 16, 24}}; - - std::array ks{key[0], key[1], key[0] ^ key[1] ^ 0x1BD11BDAu}; - ctr[0] += ks[0]; - ctr[1] += ks[1]; - - auto apply_round = [&](int r, int i0, int i1, int b) { - for (int64_t rot : rotations[r]) { - ctr[0] += ctr[1]; - ctr[1] = (ctr[1] << rot) | (ctr[1] >> (32 - rot)); - ctr[1] ^= ctr[0]; - } - ctr[0] += ks[i0]; - ctr[1] += ks[i1] + b; - }; - - apply_round(0, 1, 2, 1); - apply_round(1, 2, 0, 2); - apply_round(0, 0, 1, 3); - apply_round(1, 1, 2, 4); - apply_round(0, 2, 0, 5); - return ctr; -} - -static std::array philox4x32(std::array key, - std::array ctr) { - auto mulhilo = [](uint64_t a, uint64_t b) -> std::array { - return {static_cast((a * b) >> 32), static_cast(a * b)}; - }; - for (int i = 0; i < 10; ++i) { - auto [hi0, lo0] = mulhilo(0xD2511F53, ctr[0]); - auto [hi1, lo1] = mulhilo(0xCD9E8D57, ctr[2]); - ctr = {{hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0}}; - key[0] += 0x9E3779B9u; - key[1] += 0xBB67AE85u; - } - return ctr; -} - -template -void FillBuffer(void* buffer, void* state_buffer, int64_t size_bytes, T fn, - C ctr, std::array key) { - E* out = static_cast(buffer); - int64_t i = 0; - int64_t num = size_bytes / sizeof(E); - while (i < num) { - auto val = fn(key, ctr); - for (int64_t j = 0; j < val.size() && i < num; ++i, ++j) { - out[i] = val[j]; - } - if (!++ctr[0]) { - ++ctr[1]; - } - } - - auto state_out = static_cast(state_buffer); - state_out[0] = key[0]; - state_out[1] = key[1]; - state_out[2] = ctr[0]; - state_out[3] = ctr[1]; -} - -static absl::Status ValidateStateBuffers(FlatMemrefView state_buffer, - FlatMemrefView state_out_buffer, - bool allow_24 = false) { - if (state_buffer.size_in_bytes != 16 && - !(allow_24 && state_buffer.size_in_bytes == 24)) { - return absl::InvalidArgumentError( - absl::StrCat("Unexpected state size: ", state_buffer.size_in_bytes)); - } - if (state_out_buffer.size_in_bytes != state_buffer.size_in_bytes) { - return absl::InvalidArgumentError( - "Expected state output to have the same size as input."); - } - return absl::OkStatus(); -} - -absl::Status XlaThreeFry::operator()(const ExecutableRunOptions*, - FlatMemrefView state_buffer, - FlatMemrefView state_out_buffer, - FlatMemrefView values_buffer) const { - auto status = ValidateStateBuffers(state_buffer, state_out_buffer); - if (!status.ok()) { - return status; - } - - auto* state_vals = static_cast(state_buffer.data); - std::array key{state_vals[0], state_vals[1]}; - std::array ctr{state_vals[2], state_vals[3]}; - - switch (values_buffer.dtype) { - case S8: - case U8: - FillBuffer(values_buffer.data, state_out_buffer.data, - values_buffer.size_in_bytes, threefry2x32, ctr, key); - break; - case F16: - case U16: - case S16: - // XLA's RngBitGeneratorExpander has a corner case for bit widths less - // than 32 where it discards half the bits. We don't really need that, but - // some TF tests depend on it, somehow. - FillBuffer(values_buffer.data, state_out_buffer.data, - values_buffer.size_in_bytes, threefry2x32, ctr, key); - break; - case F32: - case U32: - case S32: - case F64: - case U64: - case S64: - FillBuffer(values_buffer.data, state_out_buffer.data, - values_buffer.size_in_bytes, threefry2x32, ctr, key); - break; - default: - return absl::UnimplementedError( - "Type not implemented by ThreeFryBitGenerator"); - } - - return absl::OkStatus(); -} - -absl::Status XlaPhilox::operator()(const ExecutableRunOptions*, - FlatMemrefView state_buffer, - FlatMemrefView state_out_buffer, - FlatMemrefView values_buffer) const { - auto status = ValidateStateBuffers(state_buffer, state_out_buffer, true); - if (!status.ok()) { - return status; - } - - auto* state_vals = static_cast(state_buffer.data); - std::array key{state_vals[0], state_vals[1]}; - bool is_24 = state_buffer.size_in_bytes == 24; - std::array ctr{state_vals[2], state_vals[3], - state_vals[is_24 ? 4 : 0], - state_vals[is_24 ? 5 : 1]}; - - switch (values_buffer.dtype) { - case S8: - case U8: - FillBuffer(values_buffer.data, state_out_buffer.data, - values_buffer.size_in_bytes, philox4x32, ctr, key); - break; - case F16: - case U16: - case S16: - FillBuffer(values_buffer.data, state_out_buffer.data, - values_buffer.size_in_bytes, philox4x32, ctr, key); - break; - case F32: - case U32: - case S32: - case F64: - case U64: - case S64: - FillBuffer(values_buffer.data, state_out_buffer.data, - values_buffer.size_in_bytes, philox4x32, ctr, key); - break; - default: - return absl::UnimplementedError( - "Type not implemented by PhiloxBitGenerator"); - } - return absl::OkStatus(); -} - -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime/rng.h b/third_party/xla/xla/service/cpu/runtime/rng.h deleted file mode 100644 index dc724ec15eb8ff..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/rng.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef XLA_SERVICE_CPU_RUNTIME_RNG_H_ -#define XLA_SERVICE_CPU_RUNTIME_RNG_H_ - -#include - -#include "absl/status/status.h" -#include "xla/executable_run_options.h" -#include "xla/runtime/memref_view.h" - -namespace xla { -namespace cpu { - -struct XlaThreeFry { - absl::Status operator()(const ExecutableRunOptions*, - xla::runtime::FlatMemrefView state_buffer, - xla::runtime::FlatMemrefView state_out_buffer, - xla::runtime::FlatMemrefView values_buffer) const; - static XlaThreeFry Handler() { return XlaThreeFry(); } -}; - -struct XlaPhilox { - absl::Status operator()(const ExecutableRunOptions*, - xla::runtime::FlatMemrefView state_buffer, - xla::runtime::FlatMemrefView state_out_buffer, - xla::runtime::FlatMemrefView values_buffer) const; - static XlaPhilox Handler() { return XlaPhilox(); } -}; - -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_RUNTIME_RNG_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/rng_call.cc b/third_party/xla/xla/service/cpu/runtime/rng_call.cc deleted file mode 100644 index 6bcbe0fe0bf7e4..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/rng_call.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "xla/service/cpu/runtime/rng_call.h" - -#include -#include - -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "xla/executable_run_options.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/cpu/runtime/rng.h" - -namespace xla { -namespace cpu { - -using ::xla::runtime::CustomCall; -using ::xla::runtime::Executable; -using ::xla::runtime::FlatMemrefView; - -// Disable all CustomCall checks in optimized build. -static constexpr CustomCall::RuntimeChecks RuntimeChecks() { -#if defined(NDEBUG) - return CustomCall::RuntimeChecks::kNone; -#else - return CustomCall::RuntimeChecks::kDefault; -#endif -} - -static bool ThreeFry(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = - CustomCall::Bind("xla_cpu_rng_three_fry") - .UserData() - .Arg() - .Arg() - .Arg() - .To(xla::cpu::XlaThreeFry::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -static bool Philox(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = - CustomCall::Bind("xla_cpu_rng_philox") - .UserData() - .Arg() - .Arg() - .Arg() - .To(xla::cpu::XlaPhilox::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -void PopulateXlaCpuRngCall(xla::runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla_cpu_rng_three_fry", &ThreeFry); - registry.Register("xla_cpu_rng_philox", &Philox); -} - -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime/rng_call.h b/third_party/xla/xla/service/cpu/runtime/rng_call.h deleted file mode 100644 index f189b90084076c..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/rng_call.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef XLA_SERVICE_CPU_RUNTIME_RNG_CALL_H_ -#define XLA_SERVICE_CPU_RUNTIME_RNG_CALL_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace cpu { - -// Populate custom call implementing XLA CPU RNGs. -void PopulateXlaCpuRngCall(runtime::DirectCustomCallRegistry& registry); - -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_RUNTIME_RNG_CALL_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/rng_ffi.cc b/third_party/xla/xla/service/cpu/runtime/rng_ffi.cc deleted file mode 100644 index 8efd9aabfade06..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/rng_ffi.cc +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "xla/service/cpu/runtime/rng_ffi.h" - -#include "absl/status/status.h" -#include "xla/runtime/aot_ffi.h" -#include "xla/runtime/aot_ffi_execution_context.h" -#include "xla/runtime/ffi/ffi_api.h" -#include "xla/runtime/ffi/ffi_c_api.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/cpu/runtime/rng.h" -#include "xla/xla_data.pb.h" - -namespace xla { -struct ExecutableRunOptions; -} // namespace xla - -namespace aot = ::xla::runtime::aot; -namespace ffi = ::xla::runtime::ffi; - -namespace { - -using ::xla::runtime::FlatMemrefView; - -// Converts an ffi::FlatBufferArg to an xla::runtime::FlatMemrefView. -FlatMemrefView ToFlatMemrefView(const ffi::FlatBufferArg& view) { - auto dtype = static_cast(view.dtype); - return FlatMemrefView{dtype, view.data, view.size_in_bytes}; -} - -ffi::FfiStatus ThreeFryFfi(xla::ExecutableRunOptions* executable_run_options, - ffi::FlatBufferArg state_buffer, - ffi::FlatBufferArg state_out_buffer, - ffi::FlatBufferArg values_buffer) { - xla::cpu::XlaThreeFry three_fry; - absl::Status status = three_fry( - executable_run_options, ToFlatMemrefView(state_buffer), - ToFlatMemrefView(state_out_buffer), ToFlatMemrefView(values_buffer)); - return status.ok() ? ffi::FfiStatus::Ok() : ffi::FfiStatus::Internal("err"); -} - -XLA_FFI_DEFINE_FUNCTION(FFI_ThreeFry, ThreeFryFfi, - ffi::Ffi::Binding() - .ApiPriv() - .Arg() // state_buffer - .Arg() // state_out_buffer - .Arg()); // values_buffer - -ffi::FfiStatus PhiloxFfi(xla::ExecutableRunOptions* executable_run_options, - ffi::FlatBufferArg state_buffer, - ffi::FlatBufferArg state_out_buffer, - ffi::FlatBufferArg values_buffer) { - xla::cpu::XlaPhilox philox; - absl::Status status = philox( - executable_run_options, ToFlatMemrefView(state_buffer), - ToFlatMemrefView(state_out_buffer), ToFlatMemrefView(values_buffer)); - return status.ok() ? ffi::FfiStatus::Ok() : ffi::FfiStatus::Internal("err"); -} - -XLA_FFI_DEFINE_FUNCTION(FFI_Philox, PhiloxFfi, - ffi::Ffi::Binding() - .ApiPriv() - .Arg() // state_buffer - .Arg() // state_out_buffer - .Arg()); // values_buffer - -} // namespace - -bool xla_cpu_rng_three_fry(void* execution_context, void** args, void** attrs, - void** rets) { - auto ctx = static_cast(execution_context); - void* executable_run_options = ctx->custom_call_data; - - XLA_FFI_Api api = aot::FfiApi(); - api.priv = executable_run_options; - - XLA_FFI_Function_Args ffi_args = aot::FfiArgs(&api, args, attrs, rets); - - XLA_FFI_Error* error = FFI_ThreeFry(&ffi_args); - return aot::ProcessErrorIfAny(error); -} - -bool xla_cpu_rng_philox(void* execution_context, void** args, void** attrs, - void** rets) { - auto ctx = static_cast(execution_context); - void* executable_run_options = ctx->custom_call_data; - - XLA_FFI_Api api = aot::FfiApi(); - api.priv = executable_run_options; - - XLA_FFI_Function_Args ffi_args = aot::FfiArgs(&api, args, attrs, rets); - - XLA_FFI_Error* error = FFI_Philox(&ffi_args); - return aot::ProcessErrorIfAny(error); -} diff --git a/third_party/xla/xla/service/cpu/runtime/rng_ffi.h b/third_party/xla/xla/service/cpu/runtime/rng_ffi.h deleted file mode 100644 index 4383f96ae45205..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/rng_ffi.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2023 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef XLA_SERVICE_CPU_RUNTIME_RNG_FFI_H_ -#define XLA_SERVICE_CPU_RUNTIME_RNG_FFI_H_ - -extern "C" { -bool xla_cpu_rng_three_fry(void* execution_context, void** args, void** attrs, - void** rets); -bool xla_cpu_rng_philox(void* execution_context, void** args, void** attrs, - void** rets); -} // extern "C" - -#endif // XLA_SERVICE_CPU_RUNTIME_RNG_FFI_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/xfeed.cc b/third_party/xla/xla/service/cpu/runtime/xfeed.cc deleted file mode 100644 index 38bb2eb34644e3..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/xfeed.cc +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2022 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "xla/service/cpu/runtime/xfeed.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "xla/executable_run_options.h" -#include "xla/primitive_util.h" -#include "xla/runtime/custom_call.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/memref_view.h" -#include "xla/service/cpu/cpu_runtime.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace cpu { - -using mlir::succeeded; - -using ::xla::runtime::CustomCall; -using ::xla::runtime::Executable; - -// Disable all CustomCall checks in optimized build. -static constexpr CustomCall::RuntimeChecks RuntimeChecks() { -#if defined(NDEBUG) - return CustomCall::RuntimeChecks::kNone; -#else - return CustomCall::RuntimeChecks::kDefault; -#endif -} - -static xla::Shape ToShape(const xla::runtime::StridedMemrefView& memref) { - // Recover `minor_to_major` dimensions permutation from strides. - auto indexed_strides_range = - llvm::map_range(llvm::enumerate(memref.strides), [](auto pair) { - return std::pair{pair.value(), pair.index()}; - }); - - auto indexed_strides = llvm::to_vector(indexed_strides_range); - llvm::stable_sort(indexed_strides); - - auto minor_to_major = - llvm::to_vector(llvm::make_second_range(indexed_strides)); - return xla::ShapeUtil::MakeShapeWithDenseLayout(memref.dtype, memref.sizes, - minor_to_major); -} - -static int64_t MemrefSize(const xla::runtime::StridedMemrefView& memref) { - int64_t size_in_bytes = primitive_util::ByteWidth(memref.dtype); - for (int64_t size : memref.sizes) { - size_in_bytes *= size; - } - return size_in_bytes; -} - -// -------------------------------------------------------------------------- // - -namespace { -struct XlaInfeed { - absl::Status operator()(const ExecutableRunOptions* run_options, - CustomCall::RemainingArgs args) const; - static XlaInfeed Handler() { return XlaInfeed(); } -}; -} // namespace - -absl::Status XlaInfeed::operator()(const ExecutableRunOptions* run_options, - CustomCall::RemainingArgs args) const { - for (unsigned i = 0; i < args.size(); ++i) { - auto memref = args.get(i); - if (!succeeded(memref)) { - return absl::InvalidArgumentError( - "Failed to get arguments as (strided) memref view"); - } - - auto size_in_bytes = static_cast(MemrefSize(*memref)); - std::string shape_string = ToShape(*memref).SerializeAsString(); - - void* infeed_buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - run_options, size_in_bytes, shape_string.data(), - static_cast(shape_string.size())); - // Copy from the infeed buffer. - std::memcpy(memref->data, infeed_buffer, size_in_bytes); - __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - run_options, size_in_bytes, infeed_buffer, shape_string.data(), - static_cast(shape_string.size())); - } - return absl::OkStatus(); -} - -static bool Infeed(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = CustomCall::Bind("xla.cpu.infeed") - .UserData() - .Arg() // args - .To(XlaInfeed::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -// -------------------------------------------------------------------------- // - -namespace { -struct XlaOutfeed { - absl::Status operator()(const ExecutableRunOptions* run_options, - CustomCall::RemainingArgs args, - absl::Span result_type) const; - static XlaOutfeed Handler() { return XlaOutfeed(); } -}; -} // namespace - -absl::Status XlaOutfeed::operator()( - const ExecutableRunOptions* run_options, CustomCall::RemainingArgs args, - absl::Span result_type) const { - assert(result_type.size() == args.size() && - "Result types and input args should be of the same size."); - for (unsigned i = 0; i < args.size(); ++i) { - auto memref = args.get(i); - if (!succeeded(memref)) { - return absl::InvalidArgumentError( - "Failed to get arguments as (strided) memref view"); - } - - // Restoring the sign information that was lost during convert-to-signless - // pass. This information was stashed in an attribute inside - // xla_cpu::outfeed. - memref->dtype = PrimitiveType(result_type[i]); - - auto size_in_bytes = static_cast(MemrefSize(*memref)); - std::string shape_string = ToShape(*memref).SerializeAsString(); - - void* outfeed_buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - run_options, size_in_bytes, shape_string.data(), - static_cast(shape_string.size())); - // Copy to the outfeed buffer. - std::memcpy(outfeed_buffer, memref->data, size_in_bytes); - __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - run_options, size_in_bytes, outfeed_buffer, shape_string.data(), - static_cast(shape_string.size())); - } - return absl::OkStatus(); -} - -static bool Outfeed(xla::runtime::ExecutionContext* ctx, void** args, - void** attrs, void** rets) { - static auto* handler = CustomCall::Bind("xla.cpu.outfeed") - .UserData() - .Arg() // args - .Attr>("result_type") - .To(XlaOutfeed::Handler()) - .release(); - return succeeded(Executable::Call(ctx, *handler, args, attrs, rets)); -} - -void PopulateXlaXfeedCall(xla::runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.cpu.infeed", &xla::cpu::Infeed); - registry.Register("xla.cpu.outfeed", &xla::cpu::Outfeed); -} - -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime/xfeed.h b/third_party/xla/xla/service/cpu/runtime/xfeed.h deleted file mode 100644 index abdb7f117edc76..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/xfeed.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2022 The OpenXLA Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef XLA_SERVICE_CPU_RUNTIME_XFEED_H_ -#define XLA_SERVICE_CPU_RUNTIME_XFEED_H_ - -#include "xla/runtime/custom_call_registry.h" - -namespace xla { -namespace cpu { - -// Populate custom call implementing XLA CPU infeed and outfeed. -void PopulateXlaXfeedCall(runtime::DirectCustomCallRegistry& registry); - -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_RUNTIME_XFEED_H_