Commit 8166ade

Merge branch 'main' into chao/xccl
2 parents: a71447e + 6899263

File tree

14 files changed: +1282, -44 lines


.github/scripts/env.sh

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ if [ "$1" != "nightly_wheel" ];then
     source /opt/intel/oneapi/compiler/latest/env/vars.sh
     source /opt/intel/oneapi/umf/latest/env/vars.sh
     source /opt/intel/oneapi/pti/latest/env/vars.sh
+    source /opt/intel/oneapi/ccl/latest/env/vars.sh
+    source /opt/intel/oneapi/mpi/latest/env/vars.sh
 else
     echo "Don't need to source DL-Essential for nightly wheel"
 fi
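
The two new source lines bring the oneCCL and Intel MPI environments into the CI shell for the XCCL work. A minimal sanity-check sketch, assuming the oneAPI vars.sh scripts export CCL_ROOT and I_MPI_ROOT and put mpirun on PATH (these variable names are an assumption, not taken from this diff):

#!/bin/bash
# Sketch: verify the oneCCL and Intel MPI environments were sourced correctly.
# CCL_ROOT / I_MPI_ROOT / mpirun are assumed to be provided by the vars.sh scripts.
source /opt/intel/oneapi/ccl/latest/env/vars.sh
source /opt/intel/oneapi/mpi/latest/env/vars.sh
for var in CCL_ROOT I_MPI_ROOT; do
    if [ -z "${!var}" ]; then
        echo "ERROR: $var is not set" >&2
        exit 1
    fi
    echo "$var=${!var}"
done
command -v mpirun >/dev/null && echo "mpirun: $(command -v mpirun)"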

.github/workflows/_linux_transformers.yml

Lines changed: 128 additions & 6 deletions
@@ -50,6 +50,7 @@ jobs:
       DisableScratchPages: ${{ inputs.driver == 'rolling' && '1' || '0' }}
       python: ${{ inputs.python != '' && inputs.python || '3.10' }}
       pytorch: ${{ inputs.pytorch != '' && inputs.pytorch || 'nightly' }}
+      transformers: ${{ inputs.transformers != '' && inputs.transformers || 'v4.47.0' }}
       TRANSFORMERS_TEST_DEVICE_SPEC: 'spec.py'
     steps:
       - name: Checkout torch-xpu-ops
@@ -60,7 +61,7 @@ jobs:
         uses: actions/checkout@v4
         with:
           repository: huggingface/transformers
-          ref: ${{ inputs.transformers != '' && inputs.transformers || 'v4.47.0' }}
+          ref: ${{ env.transformers }}
           path: transformers
       - name: Prepare OS environment
         run: |
@@ -103,13 +104,12 @@ jobs:
           rm -rf reports
           cp ${{ github.workspace }}/torch-xpu-ops/.github/scripts/spec.py ./
       - name: Report installed versions
-        id: installed
         run: |
           source activate huggingface_transformers_test
-          echo "TORCH_BRANCH_ID=$(python -c 'import torch; print(torch.__version__)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
-          echo "TORCH_COMMIT_ID=$(python -c 'import torch; print(torch.version.git_version)')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
           echo "pip installed packages:"
           pip list | tee ${{ github.workspace }}/transformers/tests_log/pip_list.txt
+          echo "lspci gpu devices:"
+          lspci -d ::0380 | tee ${{ github.workspace }}/transformers/tests_log/lspci_0380.txt
           echo "GPU render nodes:"
           cat /sys/class/drm/render*/device/device | tee ${{ github.workspace }}/transformers/tests_log/device_IDs.txt
       - name: Sanitry check installed packages
@@ -120,11 +120,133 @@ jobs:
           pip show torch | grep Version | grep xpu
           pip show torchaudio | grep Version | grep xpu
           pip show torchvision | grep Version | grep xpu
-      - name: Run XPU backbone
+          python -c 'import torch; exit(not torch.xpu.is_available())'
+      - name: Run -k backbone tests
         run: |
           source activate huggingface_transformers_test
           cd transformers
-          python3 -m pytest -rsf --make-reports=tests_benchmark -k backbone tests
+          python3 -m pytest -rsf --make-reports=tests_backbone -k backbone tests
+      - name: Run tests/pipelines
+        run: |
+          source activate huggingface_transformers_test
+          cd transformers
+          # Some tests are known to fail w/o clear pattern
+          # TODO: drop ||true after triage and fixes
+          python3 -m pytest -rsf --make-reports=tests_pipelines tests/pipelines || true
+      - name: Run tests/trainer
+        run: |
+          source activate huggingface_transformers_test
+          cd transformers
+          # Excluding tests due to:
+          # * Some ray tests hang, reason unknown
+          # * torch.distributed.* not yet supported by XPU
+          pattern=" \
+            not ray and \
+            not TestTrainerDistributed and \
+            not TestTrainerDistributedXPU and \
+            not TestFSDPTrainer"
+          python3 -m pytest -rsf --make-reports=tests_trainer tests/trainer -k "$pattern"
+      - name: Print results table
+        if: ${{ ! cancelled() }}
+        run: |
+          # Helper function to return number preceeding given pattern, i.e:
+          #   === 25 failed, 11 warnings, 0 errors ===
+          # Call as follows:
+          #   parse_stat $line "failed"
+          function parse_stat() {
+            stat=$(cat $1 | grep $2 | sed "s/.* \([0-9]*\) $2.*/\1/")
+            if [ -n "$stat" ]; then echo $stat; else echo "0"; fi
+          }
+          cd transformers
+          {
+            echo "### Results"
+            echo "| Test group | Errors | Failed | Passed | Skipped |"
+            echo "| --- | --- | --- | --- | --- |"
+            for stat in $(find reports -name stats.txt); do
+              # Each stat.txt is located in: reports/$test_group/stats.txt
+              test_group=$(echo $stat | cut -f 2 -d/)
+              # Get failed, passed, skipped, etc. counters
+              failed=$(parse_stat $stat failed)
+              passed=$(parse_stat $stat passed)
+              skipped=$(parse_stat $stat skipped)
+              warnings=$(parse_stat $stat warnings)
+              errors=$(parse_stat $stat errors)
+              echo "| $test_group | $errors | $failed | $passed | $skipped |"
+            done
+          } >> $GITHUB_STEP_SUMMARY
+      - name: Print failure lines
+        if: ${{ ! cancelled() }}
+        run: |
+          cd transformers
+          {
+            echo "### Failure lines"
+            echo "| File | Error | Comment |"
+            echo "| --- | --- | --- |"
+            rm -rf _failures.txt
+            for failure in $(find reports -name failures_line.txt); do
+              tail -n +2 $failure >> _failures.txt
+            done
+            # failures_line.txt file does not have test case information,
+            # so we can just sort the output and report uniq values
+            sort _failures.txt | uniq > _failures_uniq.txt
+            while read line; do
+              file=$(echo $line | cut -f1 -d" " | sed "s/\(.*\):$/\1/")
+              error=$(echo $line | cut -f2 -d" " | sed "s/\(.*\):$/\1/")
+              # Failure comments often contain special characters which complicate
+              # parsing failure lines. But fortunately we know for sure where comments
+              # start. So we just output all contents starting from this position and
+              # wrap everything in <pre></pre> to avoid collisions with Markdown formatting.
+              comment="<pre>$(echo $line | cut -f3- -d' ' | sed 's/\(.*\):$/\1/')</pre>"
+              echo "| $file | $error | $comment |"
+            done <_failures_uniq.txt
+          } >> $GITHUB_STEP_SUMMARY
+      - name: Print annotations
+        if: ${{ ! cancelled() }}
+        run: |
+          source activate huggingface_transformers_test
+          {
+            echo "### Annotations"
+            echo "| | |"
+            echo "| --- | --- |"
+            echo "| jobs.$GITHUB_JOB.versions.os | $(source /etc/os-release && echo $VERSION_ID) |"
+            echo "| jobs.$GITHUB_JOB.versions.linux-kernel | $(uname -r) |"
+            echo "| jobs.$GITHUB_JOB.versions.python | $(python --version | cut -f2 -d' ') |"
+            packages=" \
+              level-zero \
+              libigc1 \
+              libigc2 \
+              libze1 \
+              libze-intel-gpu1 \
+              intel-i915-dkms \
+              intel-level-zero-gpu \
+              intel-opencl-icd"
+            for package in $packages; do
+              package_version=$(dpkg -l | grep $package | grep ii | head -1 | sed "s/ */ /g" | cut -f3 -d" ")
+              echo "| jobs.$GITHUB_JOB.versions.$package | $package_version |"
+            done
+            packages="accelerate \
+              numpy \
+              torch \
+              torchaudio \
+              torchvision \
+              transformers"
+            for package in $packages; do
+              package_version=$(python -c "import $package; print($package.__version__)" || true)
+              echo "| jobs.$GITHUB_JOB.versions.$package | $package_version |"
+            done
+            # printing annotations for GPU cards
+            var="[$(cat /sys/class/drm/render*/device/vendor || true)]"
+            echo "| jobs.$GITHUB_JOB.drm.render_nodes_vendor_ids | $(echo $var | sed 's/ /,/g') |"
+            var="[$(cat /sys/class/drm/render*/device/device || true)]"
+            echo "| jobs.$GITHUB_JOB.drm.render_nodes_device_ids | $(echo $var | sed 's/ /,/g') |"
+            var=$(python -c "import torch; print(torch.version.xpu)" || true)
+            echo "| jobs.$GITHUB_JOB.torch.version.xpu | $var |"
+            var=$(python -c "import torch; print(torch.xpu.device_count())" || true)
+            echo "| jobs.$GITHUB_JOB.torch.xpu.device_count | $var |"
+            # printing annotations with key environment variables
+            echo "| jobs.$GITHUB_JOB.env.ZE_AFFINITY_MASK | $ZE_AFFINITY_MASK |"
+            echo "| jobs.$GITHUB_JOB.env.NEOReadDebugKeys | $NEOReadDebugKeys |"
+          } >> $GITHUB_STEP_SUMMARY
       - name: Upload Test log
         if: ${{ ! cancelled() }}
         uses: actions/upload-artifact@v4
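
The "Print results table" step above parses pytest's stats.txt summaries with a small parse_stat helper. For reference, a standalone sketch of how it behaves on a sample summary line (the sample text and the /tmp path are illustrative only):

#!/bin/bash
# Returns the number preceding the given pattern in a stats file, or 0 if absent.
function parse_stat() {
    stat=$(cat $1 | grep $2 | sed "s/.* \([0-9]*\) $2.*/\1/")
    if [ -n "$stat" ]; then echo $stat; else echo "0"; fi
}
echo "=== 25 failed, 11 warnings, 0 errors ===" > /tmp/stats.txt
parse_stat /tmp/stats.txt failed     # prints 25
parse_stat /tmp/stats.txt warnings   # prints 11
parse_stat /tmp/stats.txt passed     # prints 0 (pattern absent, falls back to 0)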

src/ATen/native/xpu/RNN.cpp

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+#include <ATen/ATen.h>
+#include <ATen/native/xpu/sycl/RNNKernels.h>
+
+namespace at::native {
+
+std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_xpu(
+    const Tensor& input_gates,
+    const Tensor& hidden_gates,
+    const Tensor& cx,
+    const std::optional<Tensor>& input_bias_opt,
+    const std::optional<Tensor>& hidden_bias_opt) {
+  return native::xpu::_thnn_fused_lstm_cell_kernel(
+      input_gates, hidden_gates, cx, input_bias_opt, hidden_bias_opt);
+}
+
+std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward_xpu(
+    const std::optional<Tensor>& grad_hy_opt,
+    const std::optional<Tensor>& grad_cy_opt,
+    const Tensor& cx,
+    const Tensor& cy,
+    const Tensor& workspace,
+    bool has_bias) {
+  return native::xpu::_thnn_fused_lstm_cell_backward_kernel(
+      grad_hy_opt, grad_cy_opt, cx, cy, workspace, has_bias);
+}
+
+std::tuple<at::Tensor, at::Tensor> _thnn_fused_gru_cell_xpu(
+    const Tensor& input_gates,
+    const Tensor& hidden_gates,
+    const Tensor& hx,
+    const std::optional<at::Tensor>& input_bias,
+    const std::optional<at::Tensor>& hidden_bias) {
+  return native::xpu::_thnn_fused_gru_cell_kernel(
+      input_gates, hidden_gates, hx, input_bias, hidden_bias);
+}
+
+std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
+_thnn_fused_gru_cell_backward_xpu(
+    const Tensor& grad_hy,
+    const Tensor& workspace,
+    bool has_bias) {
+  return native::xpu::_thnn_fused_gru_cell_backward_kernel(
+      grad_hy, workspace, has_bias);
+}
+
+} // namespace at::native
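
RNN.cpp wires the fused LSTM/GRU cell operators to the new SYCL kernels declared in RNNKernels.h. A small smoke-test sketch in the style of the workflow's one-liners, assuming an XPU-enabled torch build in which nn.LSTMCell dispatches to these fused kernels (the shapes are illustrative):

#!/bin/bash
# Sketch: exercise the fused LSTM cell forward/backward path on an XPU device.
python -c '
import torch
assert torch.xpu.is_available()
cell = torch.nn.LSTMCell(16, 32).to("xpu")
x = torch.randn(4, 16, device="xpu", requires_grad=True)
h, c = cell(x)                   # expected to reach _thnn_fused_lstm_cell_xpu
(h.sum() + c.sum()).backward()   # expected to reach _thnn_fused_lstm_cell_backward_xpu
print(h.shape, c.shape, x.grad.shape)
'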

src/ATen/native/xpu/XPUFallback.template

Lines changed: 0 additions & 1 deletion
@@ -185,7 +185,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
     "lu_unpack.out",
     "ormqr",
     "_scaled_mm",
-    "_thnn_fused_gru_cell",
     "_to_sparse_csr",
     "triangular_solve.X",
     "_validate_compressed_sparse_indices",

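Dropping _thnn_fused_gru_cell from the fallback list means the op is no longer routed through the CPU fallback now that an XPU kernel is registered. A quick check in the same style, again assuming an XPU-enabled torch build with these kernels:

#!/bin/bash
# Sketch: confirm a GRU cell runs on the XPU device end to end.
python -c '
import torch
assert torch.xpu.is_available()
cell = torch.nn.GRUCell(16, 32).to("xpu")
h = cell(torch.randn(4, 16, device="xpu"))
print(h.device, h.shape)   # expected: xpu:0, torch.Size([4, 32])
'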