diff --git a/.github/scripts/env.sh b/.github/scripts/env.sh
index 4fd192c06..9cfd67477 100644
--- a/.github/scripts/env.sh
+++ b/.github/scripts/env.sh
@@ -4,6 +4,8 @@ if [ "$1" != "nightly_wheel" ];then
source /opt/intel/oneapi/compiler/latest/env/vars.sh
source /opt/intel/oneapi/umf/latest/env/vars.sh
source /opt/intel/oneapi/pti/latest/env/vars.sh
+ source /opt/intel/oneapi/ccl/latest/env/vars.sh
+ source /opt/intel/oneapi/mpi/latest/env/vars.sh
else
echo "Don't need to source DL-Essential for nightly wheel"
fi
diff --git a/.github/workflows/_linux_transformers.yml b/.github/workflows/_linux_transformers.yml
index fd099fcb6..65dde1b6d 100644
--- a/.github/workflows/_linux_transformers.yml
+++ b/.github/workflows/_linux_transformers.yml
@@ -50,6 +50,7 @@ jobs:
DisableScratchPages: ${{ inputs.driver == 'rolling' && '1' || '0' }}
python: ${{ inputs.python != '' && inputs.python || '3.10' }}
pytorch: ${{ inputs.pytorch != '' && inputs.pytorch || 'nightly' }}
+ transformers: ${{ inputs.transformers != '' && inputs.transformers || 'v4.47.0' }}
TRANSFORMERS_TEST_DEVICE_SPEC: 'spec.py'
steps:
- name: Checkout torch-xpu-ops
@@ -60,7 +61,7 @@ jobs:
uses: actions/checkout@v4
with:
repository: huggingface/transformers
- ref: ${{ inputs.transformers != '' && inputs.transformers || 'v4.47.0' }}
+ ref: ${{ env.transformers }}
path: transformers
- name: Prepare OS environment
run: |
@@ -103,13 +104,12 @@ jobs:
rm -rf reports
cp ${{ github.workspace }}/torch-xpu-ops/.github/scripts/spec.py ./
- name: Report installed versions
- id: installed
run: |
source activate huggingface_transformers_test
- echo "TORCH_BRANCH_ID=$(python -c 'import torch; print(torch.__version__)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
- echo "TORCH_COMMIT_ID=$(python -c 'import torch; print(torch.version.git_version)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
echo "pip installed packages:"
pip list | tee ${{ github.workspace }}/transformers/tests_log/pip_list.txt
+ echo "lspci gpu devices:"
+ lspci -d ::0380 | tee ${{ github.workspace }}/transformers/tests_log/lspci_0380.txt
echo "GPU render nodes:"
cat /sys/class/drm/render*/device/device | tee ${{ github.workspace }}/transformers/tests_log/device_IDs.txt
- name: Sanitry check installed packages
@@ -120,11 +120,133 @@ jobs:
pip show torch | grep Version | grep xpu
pip show torchaudio | grep Version | grep xpu
pip show torchvision | grep Version | grep xpu
- - name: Run XPU backbone
+ python -c 'import torch; exit(not torch.xpu.is_available())'
+ - name: Run -k backbone tests
run: |
source activate huggingface_transformers_test
cd transformers
- python3 -m pytest -rsf --make-reports=tests_benchmark -k backbone tests
+ python3 -m pytest -rsf --make-reports=tests_backbone -k backbone tests
+ - name: Run tests/pipelines
+ run: |
+ source activate huggingface_transformers_test
+ cd transformers
+ # Some tests are known to fail w/o clear pattern
+ # TODO: drop ||true after triage and fixes
+ python3 -m pytest -rsf --make-reports=tests_pipelines tests/pipelines || true
+ - name: Run tests/trainer
+ run: |
+ source activate huggingface_transformers_test
+ cd transformers
+ # Excluding tests due to:
+ # * Some ray tests hang, reason unknown
+ # * torch.distributed.* not yet supported by XPU
+ pattern=" \
+ not ray and \
+ not TestTrainerDistributed and \
+ not TestTrainerDistributedXPU and \
+ not TestFSDPTrainer"
+ python3 -m pytest -rsf --make-reports=tests_trainer tests/trainer -k "$pattern"
+ - name: Print results table
+ if: ${{ ! cancelled() }}
+ run: |
+ # Helper function to return the number preceding a given pattern, e.g.:
+ # === 25 failed, 11 warnings, 0 errors ===
+ # Call as follows:
+ # parse_stat $line "failed"
+ function parse_stat() {
+ stat=$(cat $1 | grep $2 | sed "s/.* \([0-9]*\) $2.*/\1/")
+ if [ -n "$stat" ]; then echo $stat; else echo "0"; fi
+ }
+ cd transformers
+ {
+ echo "### Results"
+ echo "| Test group | Errors | Failed | Passed | Skipped |"
+ echo "| --- | --- | --- | --- | --- |"
+ for stat in $(find reports -name stats.txt); do
+ # Each stats.txt is located in: reports/$test_group/stats.txt
+ test_group=$(echo $stat | cut -f 2 -d/)
+ # Get failed, passed, skipped, etc. counters
+ failed=$(parse_stat $stat failed)
+ passed=$(parse_stat $stat passed)
+ skipped=$(parse_stat $stat skipped)
+ warnings=$(parse_stat $stat warnings)
+ errors=$(parse_stat $stat errors)
+ echo "| $test_group | $errors | $failed | $passed | $skipped |"
+ done
+ } >> $GITHUB_STEP_SUMMARY
+ - name: Print failure lines
+ if: ${{ ! cancelled() }}
+ run: |
+ cd transformers
+ {
+ echo "### Failure lines"
+ echo "| File | Error | Comment |"
+ echo "| --- | --- | --- |"
+ rm -rf _failures.txt
+ for failure in $(find reports -name failures_line.txt); do
+ tail -n +2 $failure >> _failures.txt
+ done
+ # The failures_line.txt file does not have test case information,
+ # so we can just sort the output and report unique values
+ sort _failures.txt | uniq > _failures_uniq.txt
+ while read line; do
+ file=$(echo $line | cut -f1 -d" " | sed "s/\(.*\):$/\1/")
+ error=$(echo $line | cut -f2 -d" " | sed "s/\(.*\):$/\1/")
+ # Failure comments often contain special characters which complicate
+ # parsing failure lines. But fortunately we know for sure where comments
+ # start. So we just output all contents starting from this position and
+ # wrap everything in <pre> to avoid collisions with Markdown formatting.
+ comment="<pre>$(echo $line | cut -f3- -d' ' | sed 's/\(.*\):$/\1/')</pre>"
+ echo "| $file | $error | $comment |"
+ done <_failures_uniq.txt
+ } >> $GITHUB_STEP_SUMMARY
+ - name: Print annotations
+ if: ${{ ! cancelled() }}
+ run: |
+ source activate huggingface_transformers_test
+ {
+ echo "### Annotations"
+ echo "| | |"
+ echo "| --- | --- |"
+ echo "| jobs.$GITHUB_JOB.versions.os | $(source /etc/os-release && echo $VERSION_ID) |"
+ echo "| jobs.$GITHUB_JOB.versions.linux-kernel | $(uname -r) |"
+ echo "| jobs.$GITHUB_JOB.versions.python | $(python --version | cut -f2 -d' ') |"
+ packages=" \
+ level-zero \
+ libigc1 \
+ libigc2 \
+ libze1 \
+ libze-intel-gpu1 \
+ intel-i915-dkms \
+ intel-level-zero-gpu \
+ intel-opencl-icd"
+ for package in $packages; do
+ package_version=$(dpkg -l | grep $package | grep ii | head -1 | sed "s/ */ /g" | cut -f3 -d" ")
+ echo "| jobs.$GITHUB_JOB.versions.$package | $package_version |"
+ done
+ packages="accelerate \
+ numpy \
+ torch \
+ torchaudio \
+ torchvision \
+ transformers"
+ for package in $packages; do
+ package_version=$(python -c "import $package; print($package.__version__)" || true)
+ echo "| jobs.$GITHUB_JOB.versions.$package | $package_version |"
+ done
+ # printing annotations for GPU cards
+ var="[$(cat /sys/class/drm/render*/device/vendor || true)]"
+ echo "| jobs.$GITHUB_JOB.drm.render_nodes_vendor_ids | $(echo $var | sed 's/ /,/g') |"
+ var="[$(cat /sys/class/drm/render*/device/device || true)]"
+ echo "| jobs.$GITHUB_JOB.drm.render_nodes_device_ids | $(echo $var | sed 's/ /,/g') |"
+ var=$(python -c "import torch; print(torch.version.xpu)" || true)
+ echo "| jobs.$GITHUB_JOB.torch.version.xpu | $var |"
+ var=$(python -c "import torch; print(torch.xpu.device_count())" || true)
+ echo "| jobs.$GITHUB_JOB.torch.xpu.device_count | $var |"
+ # printing annotations with key environment variables
+ echo "| jobs.$GITHUB_JOB.env.ZE_AFFINITY_MASK | $ZE_AFFINITY_MASK |"
+ echo "| jobs.$GITHUB_JOB.env.NEOReadDebugKeys | $NEOReadDebugKeys |"
+ } >> $GITHUB_STEP_SUMMARY
- name: Upload Test log
if: ${{ ! cancelled() }}
uses: actions/upload-artifact@v4
diff --git a/src/ATen/native/xpu/RNN.cpp b/src/ATen/native/xpu/RNN.cpp
new file mode 100644
index 000000000..74152f293
--- /dev/null
+++ b/src/ATen/native/xpu/RNN.cpp
@@ -0,0 +1,46 @@
+#include <ATen/core/Tensor.h>
+#include <ATen/native/xpu/sycl/RNNKernels.h>
+
+namespace at::native {
+
+std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_xpu(
+ const Tensor& input_gates,
+ const Tensor& hidden_gates,
+ const Tensor& cx,
+ const std::optional<Tensor>& input_bias_opt,
+ const std::optional<Tensor>& hidden_bias_opt) {
+ return native::xpu::_thnn_fused_lstm_cell_kernel(
+ input_gates, hidden_gates, cx, input_bias_opt, hidden_bias_opt);
+}
+
+std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward_xpu(
+ const std::optional<Tensor>& grad_hy_opt,
+ const std::optional<Tensor>& grad_cy_opt,
+ const Tensor& cx,
+ const Tensor& cy,
+ const Tensor& workspace,
+ bool has_bias) {
+ return native::xpu::_thnn_fused_lstm_cell_backward_kernel(
+ grad_hy_opt, grad_cy_opt, cx, cy, workspace, has_bias);
+}
+
+std::tuple<Tensor, Tensor> _thnn_fused_gru_cell_xpu(
+ const Tensor& input_gates,
+ const Tensor& hidden_gates,
+ const Tensor& hx,
+ const std::optional<Tensor>& input_bias,
+ const std::optional<Tensor>& hidden_bias) {
+ return native::xpu::_thnn_fused_gru_cell_kernel(
+ input_gates, hidden_gates, hx, input_bias, hidden_bias);
+}
+
+std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
+_thnn_fused_gru_cell_backward_xpu(
+ const Tensor& grad_hy,
+ const Tensor& workspace,
+ bool has_bias) {
+ return native::xpu::_thnn_fused_gru_cell_backward_kernel(
+ grad_hy, workspace, has_bias);
+}
+
+} // namespace at::native
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 1df3cd072..72f2aacdd 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -185,7 +185,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"lu_unpack.out",
"ormqr",
"_scaled_mm",
- "_thnn_fused_gru_cell",
"_to_sparse_csr",
"triangular_solve.X",
"_validate_compressed_sparse_indices",
diff --git a/src/ATen/native/xpu/sycl/RNNKernels.cpp b/src/ATen/native/xpu/sycl/RNNKernels.cpp
new file mode 100644
index 000000000..bad6bdf69
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/RNNKernels.cpp
@@ -0,0 +1,968 @@
+#include <ATen/AccumulateType.h>
+#include <ATen/Dispatch.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/native/CanUse32BitIndexMath.h>
+#include <ATen/ops/empty.h>
+#include <ATen/ops/empty_like.h>
+#include <c10/core/MemoryFormat.h>
+#include <sycl/sycl.hpp>
+
+#include <comm/SYCLContext.h>
+#include <comm/TensorInfo.h>
+
+#include <ATen/native/xpu/sycl/RNNKernels.h>
+
+namespace at::native::xpu {
+
+using at::native::canUse32BitIndexMath;
+using at::xpu::detail::getTensorInfo;
+using at::xpu::detail::IndexToOffset;
+using at::xpu::detail::TensorInfo;
+
+std::tuple<int64_t, int64_t> rnn_get_launch_config(
+ int64_t max_threads_per_group,
+ int64_t numel) {
+ int64_t num_groups =
+ (numel + max_threads_per_group - 1) / max_threads_per_group;
+ auto hw_max_groups = syclMaxWorkItemsPerTile() / max_threads_per_group;
+ num_groups = num_groups > hw_max_groups ? hw_max_groups : num_groups;
+ return std::make_tuple(num_groups, max_threads_per_group);
+}
+
+// Factor will be 3 for GRU and 4 for LSTM
+void checkSizes(
+ CheckedFrom c,
+ const TensorArg& input_gates,
+ const TensorArg& hidden_gates,
+ const TensorArg& input_bias,
+ const TensorArg& hidden_bias,
+ int64_t factor,
+ const TensorArg& prev_hidden) {
+ checkDim(c, input_gates, 2);
+ checkSameSize(c, input_gates, hidden_gates);
+ int64_t gates_size = input_gates->size(1);
+
+ if (input_bias->defined()) {
+ checkDim(c, input_bias, 1);
+ checkNumel(c, input_bias, gates_size);
+ checkSameSize(c, input_bias, hidden_bias);
+ }
+
+ checkDim(c, prev_hidden, 2);
+ checkNumel(c, prev_hidden, input_gates->size(0) * gates_size / factor);
+
+ checkAllSameGPU(
+ c, {input_gates, hidden_gates, input_bias, hidden_bias, prev_hidden});
+}
+
+bool allContiguous(at::TensorList tensors) {
+ return std::all_of(tensors.begin(), tensors.end(), [](const at::Tensor& t) {
+ return !t.defined() || t.is_contiguous();
+ });
+}
+
+template <typename T, typename IndexType>
+TensorInfo<T, IndexType> tryGetTensorInfo(const at::Tensor& t) {
+ return t.defined() ? getTensorInfo<T, IndexType>(t) : TensorInfo<T, IndexType>{};
+}
+
+void collapseDims(){};
+template <typename TensorInfoT, typename... Args>
+void collapseDims(TensorInfoT& info, Args&... infos) {
+ info.collapseDims();
+ collapseDims(infos...);
+}
+
+#define DEVICE_LINEAR_GET(D_TENSOR, INDEX) \
+ D_TENSOR.data[IndexToOffset<scalar_t, index_type>::get(INDEX, D_TENSOR)]
+
+// Biases are always 1D
+#define DEVICE_BIAS_GET(D_TENSOR, INDEX) \
+ D_TENSOR.data[IndexToOffset<scalar_t, index_type>::get(INDEX, D_TENSOR)]
+
+#define H2F(input) static_cast<accscalar_t>(input)
+#define F2H(input) static_cast<scalar_t>(input)
+
+template <typename T>
+inline T sigmoid(T in) {
+ T one = static_cast(1.0);
+ return one / (one + std::exp(-in));
+}
+
+template <typename scalar_t, typename accscalar_t, typename index_type>
+struct LstmCellForwardFunctor {
+ void operator()(sycl::nd_item<1> item) const {
+ bool has_bias = bias1_.data != nullptr;
+
+ for (index_type linearIndex = item.get_global_id(0);
+ linearIndex < totalElements_;
+ linearIndex += item.get_group_range(0) * item.get_local_range(0)) {
+ index_type offset = (linearIndex / hsz_) * 4 * hsz_ + linearIndex % hsz_;
+
+ scalar_t iig = DEVICE_LINEAR_GET(input_, offset + 0 * hsz_);
+ scalar_t ifg = DEVICE_LINEAR_GET(input_, offset + 1 * hsz_);
+ scalar_t icg = DEVICE_LINEAR_GET(input_, offset + 2 * hsz_);
+ scalar_t iog = DEVICE_LINEAR_GET(input_, offset + 3 * hsz_);
+
+ scalar_t hig = DEVICE_LINEAR_GET(hidden_, offset + 0 * hsz_);
+ scalar_t hfg = DEVICE_LINEAR_GET(hidden_, offset + 1 * hsz_);
+ scalar_t hcg = DEVICE_LINEAR_GET(hidden_, offset + 2 * hsz_);
+ scalar_t hog = DEVICE_LINEAR_GET(hidden_, offset + 3 * hsz_);
+
+ scalar_t* wig = &DEVICE_LINEAR_GET(workspace_, offset + 0 * hsz_);
+ scalar_t* wfg = &DEVICE_LINEAR_GET(workspace_, offset + 1 * hsz_);
+ scalar_t* wcg = &DEVICE_LINEAR_GET(workspace_, offset + 2 * hsz_);
+ scalar_t* wog = &DEVICE_LINEAR_GET(workspace_, offset + 3 * hsz_);
+
+ scalar_t cx = DEVICE_LINEAR_GET(_cx_, linearIndex);
+
+ scalar_t* hy = &DEVICE_LINEAR_GET(_hy_, linearIndex);
+ scalar_t* cy = &DEVICE_LINEAR_GET(_cy_, linearIndex);
+
+ scalar_t b1i, b1f, b1c, b1o;
+ scalar_t b2i, b2f, b2c, b2o;
+
+ if (has_bias) {
+ b1i = DEVICE_BIAS_GET(bias1_, linearIndex % hsz_ + 0 * hsz_);
+ b1f = DEVICE_BIAS_GET(bias1_, linearIndex % hsz_ + 1 * hsz_);
+ b1c = DEVICE_BIAS_GET(bias1_, linearIndex % hsz_ + 2 * hsz_);
+ b1o = DEVICE_BIAS_GET(bias1_, linearIndex % hsz_ + 3 * hsz_);
+
+ b2i = DEVICE_BIAS_GET(bias2_, linearIndex % hsz_ + 0 * hsz_);
+ b2f = DEVICE_BIAS_GET(bias2_, linearIndex % hsz_ + 1 * hsz_);
+ b2c = DEVICE_BIAS_GET(bias2_, linearIndex % hsz_ + 2 * hsz_);
+ b2o = DEVICE_BIAS_GET(bias2_, linearIndex % hsz_ + 3 * hsz_);
+ } else {
+ b1i = F2H(0.0);
+ b1f = F2H(0.0);
+ b1c = F2H(0.0);
+ b1o = F2H(0.0);
+ b2i = F2H(0.0);
+ b2f = F2H(0.0);
+ b2c = F2H(0.0);
+ b2o = F2H(0.0);
+ }
+
+ accscalar_t ig, fg, cg, og;
+ accscalar_t f_hy, f_cy;
+
+ ig = sigmoid(H2F(iig) + H2F(hig) + H2F(b1i) + H2F(b2i));
+ fg = sigmoid(H2F(ifg) + H2F(hfg) + H2F(b1f) + H2F(b2f));
+ cg = std::tanh(H2F(icg) + H2F(hcg) + H2F(b1c) + H2F(b2c));
+ og = sigmoid(H2F(iog) + H2F(hog) + H2F(b1o) + H2F(b2o));
+
+ f_cy = (fg * H2F(cx)) + (ig * cg);
+ f_hy = og * std::tanh(f_cy);
+
+ *hy = F2H(f_hy);
+ *cy = F2H(f_cy);
+
+ // SAVE FOR BACKWARDS
+ // Also need cy and cx but can be saved easily in python
+ *wig = F2H(ig);
+ *wfg = F2H(fg);
+ *wcg = F2H(cg);
+ *wog = F2H(og);
+ }
+ }
+
+ LstmCellForwardFunctor(
+ TensorInfo<scalar_t, index_type> input,
+ TensorInfo<scalar_t, index_type> hidden,
+ TensorInfo<scalar_t, index_type> bias1,
+ TensorInfo<scalar_t, index_type> bias2,
+ TensorInfo<scalar_t, index_type> _cx,
+ TensorInfo<scalar_t, index_type> _hy,
+ TensorInfo<scalar_t, index_type> _cy,
+ TensorInfo<scalar_t, index_type> workspace,
+ index_type hsz,
+ index_type totalElements)
+ : input_(input),
+ hidden_(hidden),
+ bias1_(bias1),
+ bias2_(bias2),
+ _cx_(_cx),
+ _hy_(_hy),
+ _cy_(_cy),
+ workspace_(workspace),
+ hsz_(hsz),
+ totalElements_(totalElements) {}
+
+ private:
+ TensorInfo<scalar_t, index_type> input_;
+ TensorInfo<scalar_t, index_type> hidden_;
+ TensorInfo<scalar_t, index_type> bias1_;
+ TensorInfo<scalar_t, index_type> bias2_;
+ TensorInfo<scalar_t, index_type> _cx_;
+ TensorInfo<scalar_t, index_type> _hy_;
+ TensorInfo<scalar_t, index_type> _cy_;
+ TensorInfo<scalar_t, index_type> workspace_;
+ index_type hsz_;
+ index_type totalElements_;
+};
+
+template <typename scalar_t, typename accscalar_t, typename index_type>
+struct LstmCellBackwardFunctor {
+ void operator()(sycl::nd_item<1> item) const {
+ bool has_gradoutput = gradoutput_.data != nullptr;
+ bool has_gradoutputcell = gradoutputcell_.data != nullptr;
+
+ for (index_type linearIndex = item.get_global_id(0);
+ linearIndex < totalElements_;
+ linearIndex += item.get_group_range(0) * item.get_local_range(0)) {
+ index_type offset = (linearIndex / hsz_) * 4 * hsz_ + linearIndex % hsz_;
+
+ scalar_t ig = DEVICE_LINEAR_GET(storage_, offset + 0 * hsz_);
+ scalar_t fg = DEVICE_LINEAR_GET(storage_, offset + 1 * hsz_);
+ scalar_t cg = DEVICE_LINEAR_GET(storage_, offset + 2 * hsz_);
+ scalar_t og = DEVICE_LINEAR_GET(storage_, offset + 3 * hsz_);
+
+ scalar_t* ih = &DEVICE_LINEAR_GET(gradInGates_, offset + 0 * hsz_);
+ scalar_t* fh = &DEVICE_LINEAR_GET(gradInGates_, offset + 1 * hsz_);
+ scalar_t* ch = &DEVICE_LINEAR_GET(gradInGates_, offset + 2 * hsz_);
+ scalar_t* oh = &DEVICE_LINEAR_GET(gradInGates_, offset + 3 * hsz_);
+
+ // will return hidden grads here
+ scalar_t cx = DEVICE_LINEAR_GET(_cx_, linearIndex);
+ scalar_t cy = DEVICE_LINEAR_GET(_cy_, linearIndex);
+
+ scalar_t* gi = &DEVICE_LINEAR_GET(gradInputCx_, linearIndex);
+
+ accscalar_t go = has_gradoutput
+ ? H2F(DEVICE_LINEAR_GET(gradoutput_, linearIndex))
+ : 0.f;
+ accscalar_t goc = has_gradoutputcell
+ ? H2F(DEVICE_LINEAR_GET(gradoutputcell_, linearIndex))
+ : 0.f;
+
+ accscalar_t gcx = std::tanh(H2F(cy));
+
+ accscalar_t gog = go * gcx;
+ gcx = go * H2F(og) * (1 - gcx * gcx) + goc;
+
+ accscalar_t gig = gcx * H2F(cg);
+ accscalar_t gfg = gcx * H2F(cx);
+ accscalar_t gcg = gcx * H2F(ig);
+
+ gcx = gcx * H2F(fg);
+
+ gig = gig * (1 - H2F(ig)) * H2F(ig);
+ gfg = gfg * (1 - H2F(fg)) * H2F(fg);
+ gcg = gcg * (1 - H2F(cg) * H2F(cg));
+ gog = gog * (1 - H2F(og)) * H2F(og);
+
+ *ih = F2H(gig);
+ *fh = F2H(gfg);
+ *ch = F2H(gcg);
+ *oh = F2H(gog);
+
+ *gi = F2H(gcx);
+ }
+ }
+
+ LstmCellBackwardFunctor(
+ TensorInfo<scalar_t, index_type> storage,
+ TensorInfo<scalar_t, index_type> gradInGates,
+ TensorInfo<scalar_t, index_type> _cx,
+ TensorInfo<scalar_t, index_type> _cy,
+ TensorInfo<scalar_t, index_type> gradoutput,
+ TensorInfo<scalar_t, index_type> gradoutputcell,
+ TensorInfo<scalar_t, index_type> gradInputCx,
+ index_type hsz,
+ index_type totalElements)
+ : storage_(storage),
+ gradInGates_(gradInGates),
+ _cx_(_cx),
+ _cy_(_cy),
+ gradoutput_(gradoutput),
+ gradoutputcell_(gradoutputcell),
+ gradInputCx_(gradInputCx),
+ hsz_(hsz),
+ totalElements_(totalElements) {}
+
+ private:
+ TensorInfo<scalar_t, index_type> storage_;
+ TensorInfo<scalar_t, index_type> gradInGates_;
+ TensorInfo<scalar_t, index_type> _cx_;
+ TensorInfo<scalar_t, index_type> _cy_;
+ TensorInfo<scalar_t, index_type> gradoutput_;
+ TensorInfo<scalar_t, index_type> gradoutputcell_;
+ TensorInfo<scalar_t, index_type> gradInputCx_;
+ index_type hsz_;
+ index_type totalElements_;
+};
+
+template <typename scalar_t, typename accscalar_t, typename index_type>
+struct GruCellForwardFunctor {
+ void operator()(sycl::nd_item<1> item) const {
+ bool has_bias = Bias1_.data != nullptr;
+
+ for (index_type linearIndex = item.get_global_id(0);
+ linearIndex < totalElements_;
+ linearIndex += item.get_group_range(0) * item.get_local_range(0)) {
+ index_type offset = (linearIndex / hsz_) * 3 * hsz_ + linearIndex % hsz_;
+
+ scalar_t ir = DEVICE_LINEAR_GET(Input_, offset + 0 * hsz_);
+ scalar_t ii = DEVICE_LINEAR_GET(Input_, offset + 1 * hsz_);
+ scalar_t in = DEVICE_LINEAR_GET(Input_, offset + 2 * hsz_);
+ scalar_t hr = DEVICE_LINEAR_GET(Hidden_, offset + 0 * hsz_);
+ scalar_t hi = DEVICE_LINEAR_GET(Hidden_, offset + 1 * hsz_);
+ scalar_t hn = DEVICE_LINEAR_GET(Hidden_, offset + 2 * hsz_);
+
+ scalar_t hx = DEVICE_LINEAR_GET(_hx_, linearIndex);
+ scalar_t* hy = &DEVICE_LINEAR_GET(_hy_, linearIndex);
+
+ scalar_t b1r, b1i, b1n, b2r, b2i, b2n;
+
+ if (has_bias) {
+ b1r = DEVICE_BIAS_GET(Bias1_, linearIndex % hsz_ + 0 * hsz_);
+ b1i = DEVICE_BIAS_GET(Bias1_, linearIndex % hsz_ + 1 * hsz_);
+ b1n = DEVICE_BIAS_GET(Bias1_, linearIndex % hsz_ + 2 * hsz_);
+
+ b2r = DEVICE_BIAS_GET(Bias2_, linearIndex % hsz_ + 0 * hsz_);
+ b2i = DEVICE_BIAS_GET(Bias2_, linearIndex % hsz_ + 1 * hsz_);
+ b2n = DEVICE_BIAS_GET(Bias2_, linearIndex % hsz_ + 2 * hsz_);
+ } else {
+ b1r = F2H(0.0);
+ b1i = F2H(0.0);
+ b1n = F2H(0.0);
+ b2r = F2H(0.0);
+ b2i = F2H(0.0);
+ b2n = F2H(0.0);
+ }
+
+ offset = (linearIndex / hsz_) * 5 * hsz_ + linearIndex % hsz_;
+
+ accscalar_t rg, ig, ng;
+
+ rg = sigmoid(H2F(ir) + H2F(hr) + H2F(b1r) + H2F(b2r));
+ ig = sigmoid(H2F(ii) + H2F(hi) + H2F(b1i) + H2F(b2i));
+
+ ng = H2F(in) + H2F(b1n) + rg * (H2F(hn) + H2F(b2n));
+ ng = std::tanh(ng);
+ *hy = F2H(ng + ig * (H2F(hx) - ng));
+
+ // SAVE FOR BACKWARDS
+ DEVICE_LINEAR_GET(storage_, offset + 0 * hsz_) = F2H(rg);
+ DEVICE_LINEAR_GET(storage_, offset + 1 * hsz_) = F2H(ig);
+ DEVICE_LINEAR_GET(storage_, offset + 2 * hsz_) = F2H(ng);
+ DEVICE_LINEAR_GET(storage_, offset + 3 * hsz_) = hx;
+ DEVICE_LINEAR_GET(storage_, offset + 4 * hsz_) = F2H(H2F(hn) + H2F(b2n));
+ }
+ }
+
+ GruCellForwardFunctor(
+ TensorInfo<scalar_t, index_type> Input,
+ const TensorInfo<scalar_t, index_type> Hidden,
+ const TensorInfo<scalar_t, index_type> Bias1,
+ const TensorInfo<scalar_t, index_type> Bias2,
+ const TensorInfo<scalar_t, index_type> _hx,
+ const TensorInfo<scalar_t, index_type> _hy,
+ const TensorInfo<scalar_t, index_type> storage,
+ const index_type hsz,
+ const index_type totalElements)
+ : Input_(Input),
+ Hidden_(Hidden),
+ Bias1_(Bias1),
+ Bias2_(Bias2),
+ _hx_(_hx),
+ _hy_(_hy),
+ storage_(storage),
+ hsz_(hsz),
+ totalElements_(totalElements) {}
+
+ private:
+ TensorInfo<scalar_t, index_type> Input_;
+ const TensorInfo<scalar_t, index_type> Hidden_;
+ const TensorInfo<scalar_t, index_type> Bias1_;
+ const TensorInfo<scalar_t, index_type> Bias2_;
+ const TensorInfo<scalar_t, index_type> _hx_;
+ const TensorInfo<scalar_t, index_type> _hy_;
+ const TensorInfo<scalar_t, index_type> storage_;
+ const index_type hsz_;
+ const index_type totalElements_;
+};
+
+template <typename scalar_t, typename accscalar_t, typename index_type>
+struct GruCellBackwardFunctor {
+ void operator()(sycl::nd_item<1> item) const {
+ for (index_type linearIndex = item.get_global_id(0);
+ linearIndex < totalElements_;
+ linearIndex += item.get_group_range(0) * item.get_local_range(0)) {
+ index_type offset = (linearIndex / hsz_) * 5 * hsz_ + linearIndex % hsz_;
+
+ scalar_t rg = DEVICE_LINEAR_GET(storage_, offset + 0 * hsz_);
+ scalar_t ig = DEVICE_LINEAR_GET(storage_, offset + 1 * hsz_);
+ scalar_t ng = DEVICE_LINEAR_GET(storage_, offset + 2 * hsz_);
+ scalar_t hx = DEVICE_LINEAR_GET(storage_, offset + 3 * hsz_);
+ scalar_t hn = DEVICE_LINEAR_GET(storage_, offset + 4 * hsz_);
+
+ scalar_t go = DEVICE_LINEAR_GET(gradOutput_, linearIndex);
+
+ offset = (linearIndex / hsz_) * 3 * hsz_ + linearIndex % hsz_;
+
+ accscalar_t gig = H2F(go) * (H2F(hx) - H2F(ng)) * (1 - H2F(ig)) * H2F(ig);
+ accscalar_t ghx = H2F(go) * H2F(ig);
+ accscalar_t gin = H2F(go) * (1 - H2F(ig)) * (1 - H2F(ng) * H2F(ng));
+ accscalar_t ghn = gin * H2F(rg);
+ accscalar_t grg = gin * H2F(hn) * (1 - H2F(rg)) * H2F(rg);
+
+ DEVICE_LINEAR_GET(gradInInput_, offset + 0 * hsz_) = F2H(grg);
+ DEVICE_LINEAR_GET(gradInInput_, offset + 1 * hsz_) = F2H(gig);
+ DEVICE_LINEAR_GET(gradInInput_, offset + 2 * hsz_) = F2H(gin);
+
+ DEVICE_LINEAR_GET(gradInHidden_, offset + 0 * hsz_) = F2H(grg);
+ DEVICE_LINEAR_GET(gradInHidden_, offset + 1 * hsz_) = F2H(gig);
+ DEVICE_LINEAR_GET(gradInHidden_, offset + 2 * hsz_) = F2H(ghn);
+ DEVICE_LINEAR_GET(gradInputHx_, linearIndex) = F2H(ghx);
+ }
+ }
+
+ GruCellBackwardFunctor(
+ TensorInfo<scalar_t, index_type> gradInInput,
+ TensorInfo<scalar_t, index_type> gradInHidden,
+ TensorInfo<scalar_t, index_type> gradOutput,
+ TensorInfo<scalar_t, index_type> gradInputHx,
+ TensorInfo<scalar_t, index_type> storage,
+ index_type hsz,
+ index_type totalElements)
+ : gradInInput_(gradInInput),
+ gradInHidden_(gradInHidden),
+ gradOutput_(gradOutput),
+ gradInputHx_(gradInputHx),
+ storage_(storage),
+ hsz_(hsz),
+ totalElements_(totalElements) {}
+
+ private:
+ TensorInfo<scalar_t, index_type> gradInInput_;
+ TensorInfo<scalar_t, index_type> gradInHidden_;
+ TensorInfo<scalar_t, index_type> gradOutput_;
+ TensorInfo<scalar_t, index_type> gradInputHx_;
+ TensorInfo<scalar_t, index_type> storage_;
+ index_type hsz_;
+ index_type totalElements_;
+};
+
+#undef DEVICE_LINEAR_GET
+#undef DEVICE_BIAS_GET
+#undef H2F
+#undef F2H
+
+template <typename scalar_t, typename index_type>
+void lstm_forward_impl(
+ const Tensor& input_gates,
+ const Tensor& hidden_gates,
+ const Tensor& input_bias,
+ const Tensor& hidden_bias,
+ const Tensor& cx,
+ const Tensor& hy,
+ const Tensor& cy,
+ const Tensor& workspace) {
+ using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+
+ int64_t numel = cx.numel();
+ if (numel == 0)
+ return;
+
+ using KernelT = LstmCellForwardFunctor<scalar_t, accscalar_t, index_type>;
+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
+ auto config = rnn_get_launch_config(max_wg_size, numel);
+ auto nwg = std::get<0>(config);
+ auto local_range = std::get<1>(config);
+
+ auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates);
+ auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates);
+ auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias);
+ auto hidden_biasI = tryGetTensorInfo<scalar_t, index_type>(hidden_bias);
+ auto cxI = getTensorInfo<scalar_t, index_type>(cx);
+ auto hyI = getTensorInfo<scalar_t, index_type>(hy);
+ auto cyI = getTensorInfo<scalar_t, index_type>(cy);
+ auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
+ index_type hidden_size = cxI.sizes[cxI.dims - 1];
+
+ if (allContiguous(
+ {input_gates,
+ hidden_gates,
+ input_bias,
+ hidden_bias,
+ cx,
+ hy,
+ cy,
+ workspace})) {
+ collapseDims(
+ input_gatesI,
+ hidden_gatesI,
+ input_biasI,
+ hidden_biasI,
+ cxI,
+ hyI,
+ cyI,
+ workspaceI);
+ KernelT kfn(
+ input_gatesI,
+ hidden_gatesI,
+ input_biasI,
+ hidden_biasI,
+ cxI,
+ hyI,
+ cyI,
+ workspaceI,
+ hidden_size,
+ numel);
+ sycl_kernel_submit(
+ nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
+ } else {
+ KernelT kfn(
+ input_gatesI,
+ hidden_gatesI,
+ input_biasI,
+ hidden_biasI,
+ cxI,
+ hyI,
+ cyI,
+ workspaceI,
+ hidden_size,
+ numel);
+ sycl_kernel_submit(
+ nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
+ }
+}
+
+template <typename scalar_t, typename index_type>
+void lstm_backward_impl(
+ const Tensor& grad_hy,
+ const Tensor& grad_cy,
+ const Tensor& cx,
+ const Tensor& cy,
+ const Tensor& workspace,
+ const Tensor& grad_gates,
+ const Tensor& grad_cx) {
+ using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+
+ int64_t numel = cx.numel();
+ if (numel == 0)
+ return;
+
+ using KernelT = LstmCellBackwardFunctor<scalar_t, accscalar_t, index_type>;
+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
+ auto config = rnn_get_launch_config(max_wg_size, numel);
+ auto nwg = std::get<0>(config);
+ auto local_range = std::get<1>(config);
+
+ auto grad_hyI = tryGetTensorInfo<scalar_t, index_type>(grad_hy);
+ auto grad_cyI = tryGetTensorInfo<scalar_t, index_type>(grad_cy);
+ auto cxI = getTensorInfo<scalar_t, index_type>(cx);
+ auto cyI = getTensorInfo<scalar_t, index_type>(cy);
+ auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
+ auto grad_gatesI = getTensorInfo<scalar_t, index_type>(grad_gates);
+ auto grad_cxI = getTensorInfo<scalar_t, index_type>(grad_cx);
+ index_type hidden_size = cxI.sizes[cxI.dims - 1];
+
+ if (allContiguous(
+ {grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx})) {
+ collapseDims(
+ grad_hyI, grad_cyI, cxI, cyI, workspaceI, grad_gatesI, grad_cxI);
+ KernelT kfn(
+ workspaceI,
+ grad_gatesI,
+ cxI,
+ cyI,
+ grad_hyI,
+ grad_cyI,
+ grad_cxI,
+ hidden_size,
+ numel);
+ sycl_kernel_submit(
+ nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
+ } else {
+ KernelT kfn(
+ workspaceI,
+ grad_gatesI,
+ cxI,
+ cyI,
+ grad_hyI,
+ grad_cyI,
+ grad_cxI,
+ hidden_size,
+ numel);
+ sycl_kernel_submit(
+ nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
+ }
+}
+
+template <typename scalar_t, typename index_type>
+void gru_forward_impl(
+ const Tensor& input_gates,
+ const Tensor& hidden_gates,
+ const Tensor& input_bias,
+ const Tensor& hidden_bias,
+ const Tensor& hx,
+ const Tensor& hy,
+ const Tensor& workspace) {
+ using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+
+ int64_t numel = hx.numel();
+ if (numel == 0)
+ return;
+
+ using KernelT = GruCellForwardFunctor<scalar_t, accscalar_t, index_type>;
+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
+ auto config = rnn_get_launch_config(max_wg_size, numel);
+ auto nwg = std::get<0>(config);
+ auto local_range = std::get<1>(config);
+
+ auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates);
+ auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates);
+ auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias);
+ auto hidden_biasI = tryGetTensorInfo<scalar_t, index_type>(hidden_bias);
+ auto hxI = getTensorInfo<scalar_t, index_type>(hx);
+ auto hyI = getTensorInfo<scalar_t, index_type>(hy);
+ auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
+ index_type hidden_size = hxI.sizes[hxI.dims - 1];
+
+ if (allContiguous(
+ {input_gates,
+ hidden_gates,
+ input_bias,
+ hidden_bias,
+ hx,
+ hy,
+ workspace})) {
+ collapseDims(
+ input_gatesI,
+ hidden_gatesI,
+ input_biasI,
+ hidden_biasI,
+ hxI,
+ hyI,
+ workspaceI);
+ KernelT kfn(
+ input_gatesI,
+ hidden_gatesI,
+ input_biasI,
+ hidden_biasI,
+ hxI,
+ hyI,
+ workspaceI,
+ hidden_size,
+ numel);
+ sycl_kernel_submit(
+ nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
+ } else {
+ KernelT kfn(
+ input_gatesI,
+ hidden_gatesI,
+ input_biasI,
+ hidden_biasI,
+ hxI,
+ hyI,
+ workspaceI,
+ hidden_size,
+ numel);
+ sycl_kernel_submit(
+ nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
+ }
+}
+
+template <typename scalar_t, typename index_type>
+void gru_backward_impl(
+ const Tensor& grad_hy,
+ const Tensor& workspace,
+ const Tensor& grad_input_gates,
+ const Tensor& grad_hidden_gates,
+ const Tensor& grad_hx) {
+ using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+
+ int64_t numel = grad_hy.numel();
+ if (numel == 0)
+ return;
+
+ using KernelT = GruCellBackwardFunctor<scalar_t, accscalar_t, index_type>;
+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
+ auto config = rnn_get_launch_config(max_wg_size, numel);
+ auto nwg = std::get<0>(config);
+ auto local_range = std::get<1>(config);
+
+ auto grad_hyI = getTensorInfo<scalar_t, index_type>(grad_hy);
+ auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
+ auto grad_input_gatesI =
+ getTensorInfo<scalar_t, index_type>(grad_input_gates);
+ auto grad_hidden_gatesI =
+ getTensorInfo<scalar_t, index_type>(grad_hidden_gates);
+ auto grad_hxI = getTensorInfo<scalar_t, index_type>(grad_hx);
+ index_type hidden_size = grad_hyI.sizes[grad_hyI.dims - 1];
+
+ if (allContiguous(
+ {grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx})) {
+ collapseDims(
+ grad_hyI, workspaceI, grad_input_gatesI, grad_hidden_gatesI, grad_hxI);
+ KernelT kfn(
+ grad_input_gatesI,
+ grad_hidden_gatesI,
+ grad_hyI,
+ grad_hxI,
+ workspaceI,
+ hidden_size,
+ numel);
+ sycl_kernel_submit(
+ nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
+ } else {
+ KernelT kfn(
+ grad_input_gatesI,
+ grad_hidden_gatesI,
+ grad_hyI,
+ grad_hxI,
+ workspaceI,
+ hidden_size,
+ numel);
+ sycl_kernel_submit(
+ nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
+ }
+}
+
+// Note [64-bit index math check elision]
+// It's enough to perform the check for 64-bit math on the largest tensor only.
+// If 32-bit is enough for it, it will suffice for all other tensors too, and we
+// can save some work using this trick.
+
+std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_kernel(
+ const Tensor& input_gates,
+ const Tensor& hidden_gates,
+ const Tensor& cx,
+ const std::optional<Tensor>& input_bias_opt,
+ const std::optional<Tensor>& hidden_bias_opt) {
+ // See [Note: hacky wrapper removal for optional tensor]
+ c10::MaybeOwned<Tensor> input_bias_maybe_owned =
+ at::borrow_from_optional_tensor(input_bias_opt);
+ const Tensor& input_bias = *input_bias_maybe_owned;
+ const Tensor& hidden_bias = hidden_bias_opt.value_or(Tensor());
+
+ checkSizes(
+ "_thnn_fused_lstm_cell_xpu",
+ {input_gates, "input_gates", 1},
+ {hidden_gates, "hidden_gates", 2},
+ {input_bias, "input_bias", 3},
+ {hidden_bias, "hidden_bias", 4},
+ /*factor=*/4,
+ {cx, "prev_hidden", 5});
+
+ auto workspace = at::empty_like(input_gates, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto hy = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto cy = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ input_gates.scalar_type(),
+ "_thnn_fused_lstm_cell_xpu",
+ [&] {
+ if (canUse32BitIndexMath(
+ workspace)) { // See Note [64-bit index math check elision]
+ lstm_forward_impl<scalar_t, int32_t>(
+ input_gates,
+ hidden_gates,
+ input_bias,
+ hidden_bias,
+ cx,
+ hy,
+ cy,
+ workspace);
+ } else {
+ lstm_forward_impl<scalar_t, int64_t>(
+ input_gates,
+ hidden_gates,
+ input_bias,
+ hidden_bias,
+ cx,
+ hy,
+ cy,
+ workspace);
+ }
+ });
+ return std::make_tuple(std::move(hy), std::move(cy), std::move(workspace));
+}
+
+void checkLSTMBackwardSizes(
+ const TensorArg& grad_hy,
+ const TensorArg& grad_cy,
+ const TensorArg& cx,
+ const TensorArg& cy,
+ const TensorArg& workspace) {
+ CheckedFrom c = "fused_lstm_cell_backward";
+ const TensorArg& defined_grad = grad_hy->defined() ? grad_hy : grad_cy;
+ checkDim(c, defined_grad, 2);
+ auto exp_size = defined_grad->sizes();
+ if (grad_hy->defined()) {
+ checkSize(c, grad_hy, exp_size);
+ }
+ if (grad_cy->defined()) {
+ checkSize(c, grad_cy, exp_size);
+ }
+ checkSize(c, cx, exp_size);
+ checkSize(c, cy, exp_size);
+ checkDim(c, workspace, 2);
+ checkNumel(c, workspace, exp_size[0] * exp_size[1] * 4);
+}
+
+std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward_kernel(
+ const std::optional<Tensor>& grad_hy_opt,
+ const std::optional<Tensor>& grad_cy_opt,
+ const Tensor& cx,
+ const Tensor& cy,
+ const Tensor& workspace,
+ bool has_bias) {
+ // See [Note: hacky wrapper removal for optional tensor]
+ c10::MaybeOwned<Tensor> grad_hy_maybe_owned =
+ at::borrow_from_optional_tensor(grad_hy_opt);
+ const Tensor& grad_hy = *grad_hy_maybe_owned;
+ const Tensor& grad_cy = grad_cy_opt.value_or(Tensor());
+
+ if (!grad_hy.defined() && !grad_cy.defined()) {
+ return std::tuple<Tensor, Tensor, Tensor>();
+ }
+ checkLSTMBackwardSizes(
+ {grad_hy, "grad_hy", 1},
+ {grad_cy, "grad_cy", 2},
+ {cx, "cx", 3},
+ {cy, "cy", 4},
+ {workspace, "workspace", 5});
+
+ auto grad_gates = at::empty_like(workspace, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto grad_cx = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ workspace.scalar_type(),
+ "_thnn_fused_lstm_cell_backward_xpu",
+ [&] {
+ if (canUse32BitIndexMath(
+ workspace)) { // See Note [64-bit index math check elision]
+ lstm_backward_impl<scalar_t, int32_t>(
+ grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx);
+ } else {
+ lstm_backward_impl<scalar_t, int64_t>(
+ grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx);
+ }
+ });
+
+ auto grad_bias =
+ has_bias ? grad_gates.sum(0, /*keepdim=*/false) : at::Tensor{};
+ return std::make_tuple(
+ std::move(grad_gates), std::move(grad_cx), std::move(grad_bias));
+}
+
+static constexpr int64_t GRU_WORKSPACE_MULTIPLIER = 5;
+
+std::tuple<Tensor, Tensor> _thnn_fused_gru_cell_kernel(
+ const Tensor& input_gates,
+ const Tensor& hidden_gates,
+ const Tensor& hx,
+ const std::optional<Tensor>& input_bias_opt,
+ const std::optional<Tensor>& hidden_bias_opt) {
+ // See [Note: hacky wrapper removal for optional tensor]
+ c10::MaybeOwned<Tensor> input_bias_maybe_owned =
+ at::borrow_from_optional_tensor(input_bias_opt);
+ const Tensor& input_bias = *input_bias_maybe_owned;
+ const Tensor& hidden_bias = hidden_bias_opt.value_or(Tensor());
+
+ checkSizes(
+ "_thnn_fused_gru_cell_xpu",
+ {input_gates, "input_gates", 1},
+ {hidden_gates, "hidden_gates", 2},
+ {input_bias, "input_bias", 3},
+ {hidden_bias, "hidden_bias", 4},
+ /*factor=*/3,
+ {hx, "prev_hidden", 5});
+
+ auto workspace = at::empty(
+ {hx.size(0), hx.size(1) * GRU_WORKSPACE_MULTIPLIER}, hx.options());
+ auto hy = at::empty_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ input_gates.scalar_type(),
+ "_thnn_fused_gru_cell_xpu",
+ [&] {
+ if (canUse32BitIndexMath(
+ workspace)) { // See Note [64-bit index math check elision]
+ gru_forward_impl<scalar_t, int32_t>(
+ input_gates,
+ hidden_gates,
+ input_bias,
+ hidden_bias,
+ hx,
+ hy,
+ workspace);
+ } else {
+ gru_forward_impl<scalar_t, int64_t>(
+ input_gates,
+ hidden_gates,
+ input_bias,
+ hidden_bias,
+ hx,
+ hy,
+ workspace);
+ }
+ });
+ return std::make_tuple(std::move(hy), std::move(workspace));
+}
+
+void checkGRUBackwardSizes(
+ const TensorArg& grad_hy,
+ const TensorArg& workspace) {
+ CheckedFrom c = "fused_gru_cell_backward";
+ checkDim(c, grad_hy, 2);
+ checkSize(
+ c,
+ workspace,
+ {grad_hy->size(0), grad_hy->size(1) * GRU_WORKSPACE_MULTIPLIER});
+}
+
+std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
+_thnn_fused_gru_cell_backward_kernel(
+ const Tensor& grad_hy,
+ const Tensor& workspace,
+ bool has_bias) {
+ checkGRUBackwardSizes({grad_hy, "grad_hy", 1}, {workspace, "workspace", 2});
+
+ int64_t hidden_size = workspace.size(1) / GRU_WORKSPACE_MULTIPLIER;
+ auto grad_input_gates =
+ at::empty({workspace.size(0), hidden_size * 3}, workspace.options());
+ auto grad_hidden_gates =
+ at::empty({workspace.size(0), hidden_size * 3}, workspace.options());
+ auto grad_hx = at::empty_like(grad_hy, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ grad_hy.scalar_type(),
+ "_thnn_fused_gru_cell_backward_xpu",
+ [&] {
+ if (canUse32BitIndexMath(
+ workspace)) { // See Note [64-bit index math check elision]
+ gru_backward_impl<scalar_t, int32_t>(
+ grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx);
+ } else {
+ gru_backward_impl<scalar_t, int64_t>(
+ grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx);
+ }
+ });
+
+ at::Tensor grad_input_bias, grad_hidden_bias;
+ if (has_bias) {
+ grad_input_bias = grad_input_gates.sum(0, /*keepdim=*/false);
+ grad_hidden_bias = grad_hidden_gates.sum(0, /*keepdim=*/false);
+ }
+
+ return std::make_tuple(
+ std::move(grad_input_gates),
+ std::move(grad_hidden_gates),
+ std::move(grad_hx),
+ std::move(grad_input_bias),
+ std::move(grad_hidden_bias));
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/RNNKernels.h b/src/ATen/native/xpu/sycl/RNNKernels.h
new file mode 100644
index 000000000..07f0e3f78
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/RNNKernels.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+
+namespace at::native::xpu {
+
+TORCH_XPU_API std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_kernel(
+ const Tensor& input_gates,
+ const Tensor& hidden_gates,
+ const Tensor& cx,
+ const std::optional<Tensor>& input_bias_opt,
+ const std::optional<Tensor>& hidden_bias_opt);
+
+TORCH_XPU_API std::tuple<Tensor, Tensor, Tensor>
+_thnn_fused_lstm_cell_backward_kernel(
+ const std::optional<Tensor>& grad_hy_opt,
+ const std::optional<Tensor>& grad_cy_opt,
+ const Tensor& cx,
+ const Tensor& cy,
+ const Tensor& workspace,
+ bool has_bias);
+
+TORCH_XPU_API std::tuple<Tensor, Tensor> _thnn_fused_gru_cell_kernel(
+ const Tensor& input_gates,
+ const Tensor& hidden_gates,
+ const Tensor& hx,
+ const std::optional<Tensor>& input_bias_opt,
+ const std::optional<Tensor>& hidden_bias_opt);
+
+TORCH_XPU_API std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
+_thnn_fused_gru_cell_backward_kernel(
+ const Tensor& grad_hy,
+ const Tensor& workspace,
+ bool has_bias);
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/ResizeKernel.cpp b/src/ATen/native/xpu/sycl/ResizeKernel.cpp
index 237a1c213..f1ee7f944 100644
--- a/src/ATen/native/xpu/sycl/ResizeKernel.cpp
+++ b/src/ATen/native/xpu/sycl/ResizeKernel.cpp
@@ -25,8 +25,9 @@ void resize_bytes_xpu(StorageImpl* storage, size_t size_bytes) {
c10::xpu::XPUGuard guard(device.index());
at::DataPtr data = allocator->allocate(size_bytes);
if (storage->data_ptr()) {
- auto q = at::xpu::getCurrentSYCLQueue();
+ at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
+ auto q = at::xpu::getCurrentSYCLQueue();
q.memcpy(
data.get(), storage->data(), std::min(storage->nbytes(), size_bytes));
}
diff --git a/src/ATen/xpu/EmptyTensor.cpp b/src/ATen/xpu/EmptyTensor.cpp
index 3f5e998f8..6411bb221 100644
--- a/src/ATen/xpu/EmptyTensor.cpp
+++ b/src/ATen/xpu/EmptyTensor.cpp
@@ -54,6 +54,7 @@ TensorBase empty_strided_xpu(
IntArrayRef stride,
ScalarType dtype,
c10::optional<Device> device_opt) {
+ at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
const auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT(device.is_xpu());
const c10::DeviceGuard device_guard(device);
diff --git a/test/xpu/extended/skip_list_arc.py b/test/xpu/extended/skip_list_arc.py
index e1e701b84..c8e26ccf3 100644
--- a/test/xpu/extended/skip_list_arc.py
+++ b/test/xpu/extended/skip_list_arc.py
@@ -7,5 +7,21 @@
"test_compare_cpu_bincount_xpu_int64",
"test_compare_cpu_bincount_xpu_int8",
"test_compare_cpu_bincount_xpu_uint8",
+ # RuntimeError: Kernel is incompatible with all devices in devs
+ # https://github.com/intel/torch-xpu-ops/issues/1150
+ "test_compare_cpu_logcumsumexp_xpu_float16",
+ "test_compare_cpu_logcumsumexp_xpu_float32",
+ "test_compare_cpu_nn_functional_pdist_xpu_float32",
+ "test_compare_cpu_tril_indices_xpu_int32",
+ "test_compare_cpu_tril_indices_xpu_int64",
+ "test_compare_cpu_triu_indices_xpu_int32",
+ "test_compare_cpu_triu_indices_xpu_int64",
+ "test_backward_logcumsumexp_xpu_float32",
+ "test_backward_nn_functional_pdist_xpu_float32",
+ "test_forward_ad_logcumsumexp_xpu_float32",
+ "test_operator_logcumsumexp_xpu_float32",
+ "test_operator_nn_functional_pdist_xpu_float32",
+ "test_view_replay_logcumsumexp_xpu_float32",
+ "test_view_replay_nn_functional_pdist_xpu_float32",
),
}
diff --git a/test/xpu/extended/skip_list_common.py b/test/xpu/extended/skip_list_common.py
index 6b5fd653e..643d631eb 100644
--- a/test/xpu/extended/skip_list_common.py
+++ b/test/xpu/extended/skip_list_common.py
@@ -194,5 +194,9 @@
# Greatest absolute difference: 0.0625 at index (1,) (up to 0.001 allowed)
# Greatest relative difference: 0.00640869140625 at index (1,) (up to 0.001 allowed)
"test_compare_cpu_xlogy_xpu_bfloat16",
+ "test_compare_cpu_div_trunc_rounding_xpu_float64",
+ "test_compare_cpu_div_trunc_rounding_xpu_float16",
+ "test_compare_cpu_div_floor_rounding_xpu_float16",
+ "test_compare_cpu_div_floor_rounding_xpu_bfloat16",
),
}
diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py
index e1903f871..52a93d91b 100644
--- a/test/xpu/skip_list_common.py
+++ b/test/xpu/skip_list_common.py
@@ -649,6 +649,14 @@
"test_python_ref__refs_square_xpu_complex64",
"test_python_ref_torch_fallback__refs_square_xpu_complex64",
"test_python_ref_torch_fallback__refs_exp_xpu_complex128",
+
+ # Failed on rolling driver, passed on preci
+ "test_python_ref__refs_div_trunc_rounding_xpu_float64",
+ "test_python_ref_executor__refs_div_trunc_rounding_executor_aten_xpu_float64",
+ "test_python_ref_torch_fallback__refs_div_trunc_rounding_xpu_float64",
+
+ # TODO: passes when PyTorch is built from source, investigate
+ "test_python_ref__refs_log2_xpu_complex128",
),
"test_binary_ufuncs_xpu.py": (
@@ -939,38 +947,6 @@
# CPU fallback fails
# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
- # aten::_thnn_fused_gru_cell not support XPU backend
- "test_save_load_nn_GRU_eval_mode_xpu_float32",
- "test_save_load_nn_GRUCell_xpu_float32",
- "test_save_load_nn_GRU_train_mode_xpu_float32",
-
- # aten::_thnn_fused_lstm_cell not support XPU backend
- # Could not run 'aten::_thnn_fused_lstm_cell' with arguments from the 'CPU' backend.
- "_LSTM_",
- "_LSTMCell_",
-
- # aten::_thnn_fused_gru_cell not support XPU backend
- # CPU fallback fails
- # Could not run 'aten::_thnn_fused_gru_cell' with arguments from the 'CPU' backend.
- "test_to_nn_GRUCell_swap_True_set_grad_False_xpu_float32",
- "test_to_nn_GRU_eval_mode_swap_True_set_grad_False_xpu_float32",
- "test_to_nn_GRU_train_mode_swap_True_set_grad_False_xpu_float32 ",
- "test_cpu_gpu_parity_nn_GRUCell_xpu_float32",
- "test_cpu_gpu_parity_nn_GRU_eval_mode_xpu_float32",
- "test_cpu_gpu_parity_nn_GRU_train_mode_xpu_float32",
- "test_forward_nn_GRUCell_xpu_float32",
- "test_forward_nn_GRU_eval_mode_xpu_float32",
- "test_forward_nn_GRU_train_mode_xpu_float32",
- "test_if_train_and_eval_modes_differ_nn_GRUCell_xpu_float32",
- "test_memory_format_nn_GRUCell_xpu_float32",
- "test_memory_format_nn_GRU_eval_mode_xpu_float32",
- "test_memory_format_nn_GRU_train_mode_xpu_float32",
- "test_multiple_device_transfer_nn_GRUCell_xpu_float32",
- "test_multiple_device_transfer_nn_GRU_eval_mode_xpu_float32",
- "test_multiple_device_transfer_nn_GRU_train_mode_xpu_float32",
- "test_non_contiguous_tensors_nn_GRUCell_xpu_float32",
- "test_non_contiguous_tensors_nn_GRU_eval_mode_xpu_float32",
- "test_non_contiguous_tensors_nn_GRU_train_mode_xpu_float32",
# AssertionError: False is not true
"test_to_nn_BatchNorm1d_eval_mode_swap_True_set_grad_True_xpu_float32",
"test_to_nn_BatchNorm1d_train_mode_swap_True_set_grad_True_xpu_float32",
@@ -1118,8 +1094,6 @@
# Sometimes, will raise AssertionError: "Simulate error" does not match "grad can be implicitly created only for scalar outputs"
# https://github.com/intel/torch-xpu-ops/issues/1071
"test_reentrant_parent_error_on_cpu_xpu",
- # Could not run 'aten::_thnn_fused_lstm_cell' with arguments from the 'CPU' backend.
- "test_rnn_backward_to_input_but_not_parameters_xpu",
),
"test_reductions_xpu.py": (
@@ -1170,6 +1144,7 @@
# Greatest relative difference: 1.9145216356264427e-05 at index (463, 204) (up to 1.3e-06 allowed)
"test_reference_numerics_normal__refs_asinh_xpu_complex64",
"test_reference_numerics_normal_asinh_xpu_complex64",
+ "test_batch_vs_slicing__refs_sigmoid_xpu_complex128",
# Unexpected success: CUDA uses thrust::sqrt and has accuracy issue. XPU use std::sqrt and has no issue.
"test_reference_numerics_large_rsqrt_xpu_complex32",
# Numeric difference
@@ -1548,6 +1523,8 @@
# XPU does not support tunable.
"test_bmm_tunableop_rocm_xpu_float32",
"test_numeric_check_leak_tunableop_rocm_xpu_float32",
+ "test_dump_results_on_exit_tunableop_xpu_float32",
+ "test_rotating_buffer_tunableop_xpu_float32",
# CUDA bias cases added in latest PyTorch
# AttributeError: module 'torch._C' has no attribute '_cuda_tunableop_enable'
"test_matmul_check_entries_tunableop_xpu_float16",
@@ -3264,7 +3241,10 @@
"test_type_promotion_xpu.py": None,
- "test_distributions_xpu.py": None,
+ "test_distributions_xpu.py": (
+ # TODO: Passed on lts driver version, but failed on rolling driver version
+ "test_gamma_gpu_sample_xpu",
+ ),
"test_optim_xpu.py": (
# oneDNN issues
diff --git a/test/xpu/test_unary_ufuncs_xpu.py b/test/xpu/test_unary_ufuncs_xpu.py
index 0e05a8e7c..a6c12a2ad 100644
--- a/test/xpu/test_unary_ufuncs_xpu.py
+++ b/test/xpu/test_unary_ufuncs_xpu.py
@@ -1,6 +1,7 @@
# Owner(s): ["module: intel"]
-from torch.testing._internal.common_device_type import instantiate_device_type_tests
+import torch
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyXPU
from torch.testing._internal.common_utils import run_tests
try:
@@ -11,6 +12,38 @@
with XPUPatchForImport(False):
from test_unary_ufuncs import TestUnaryUfuncs
+ @onlyXPU
+ def _nonzero_static_large(self, device):
+ # large enough to have multiple iters per SM even on H100
+ # with 132 sms
+ size_inp = 1024 * 16 * 132 + 1024 * 16
+ x = torch.zeros(size_inp, device=device)
+ # unique indices
+ indices = torch.randperm(size_inp, device=device)[: size_inp // 2]
+ sorted, _ = torch.sort(indices)
+ x[sorted] = 1
+ res = torch.nonzero_static(x, size=size_inp // 2).view(-1)
+ self.assertEqual(res, sorted)
+ # no oob writes
+ out = torch.full((size_inp,), 10, device=device, dtype=torch.int64)
+ res = torch.nonzero_static(x, size=size_inp // 4, out=out[: size_inp // 2])
+ self.assertEqual(out[: size_inp // 4], sorted[: size_inp // 4])
+ self.assertEqual(
+ out[size_inp // 4 :],
+ torch.tensor(10, device="xpu").expand_as(out[size_inp // 4 :]),
+ )
+ # correct fill for 2d
+ x = x.view(2, size_inp // 2)
+ ref = x.nonzero()
+ res = x.nonzero_static(size=size_inp // 2 + 2)
+ self.assertEqual(res.shape, [size_inp // 2 + 2, 2])
+ self.assertEqual(ref, res[: size_inp // 2])
+ self.assertEqual(
+ res[size_inp // 2 :],
+ torch.tensor(-1, device="xpu").expand_as(res[size_inp // 2 :]),
+ )
+ TestUnaryUfuncs.test_nonzero_static_large = _nonzero_static_large
+
instantiate_device_type_tests(TestUnaryUfuncs, globals(),only_for=("xpu"), allow_xpu=True)
if __name__ == "__main__":
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 6c31415cc..1d18a27e2 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -223,6 +223,8 @@
"nn.functional.ctc_loss",
"nn.functional.channel_shuffle",
"nn.functional.multi_head_attention_forward",
+ "nn.GRUCell",
+ "nn.LSTMCell",
"sigmoid",
"logsigmoid",
"sgn",
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index f76f49fb8..cbd57c762 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -7572,6 +7572,34 @@
dispatch:
XPU: ctc_loss_backward_tensor
+- func: lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
+
+- func: gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor
+
+# Fused RNN kernels
+- func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor)
+ dispatch:
+ XPU: _thnn_fused_lstm_cell_xpu
+ autogen: _thnn_fused_lstm_cell.out
+
+# NB: The composite version of this function below is a simple wrapper that duplicates some of the outputs
+# It is necessary to avoid triggering TensorImpl use count checks in debug mode
+# NB: this function is NOT differentiable
+- func: _thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor)
+ dispatch:
+ XPU: _thnn_fused_lstm_cell_backward_xpu
+ autogen: _thnn_fused_lstm_cell_backward_impl.out
+
+- func: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor)
+ dispatch:
+ XPU: _thnn_fused_gru_cell_xpu
+ autogen: _thnn_fused_gru_cell.out
+
+- func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+ dispatch:
+ XPU: _thnn_fused_gru_cell_backward_xpu
+ autogen: _thnn_fused_gru_cell_backward.out
+
- func: hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase