
Commit cbb4ab1

Add aten::huber_loss/backward and their variants (#562)
Task list:
- [x] huber_loss
- [x] huber_loss.out
- [x] huber_loss_backward.out
1 parent 34f00ad · commit cbb4ab1

File tree

9 files changed: +131 −4 lines changed

src/ATen/native/xpu/Loss.cpp

Lines changed: 61 additions & 2 deletions
@@ -1,14 +1,24 @@
 #include <ATen/ATen.h>
 #include <ATen/core/Reduction.h>
 #include <ATen/core/Tensor.h>
-#include <ATen/xpu/XPUNativeFunctions.h>
-
 #include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
 #include <ATen/native/xpu/sycl/PointwiseOpsKernels.h>
+#include <ATen/xpu/XPUNativeFunctions.h>
 #include <comm/RegisterUtils.h>
 
 namespace at {
 
+static inline at::Tensor apply_loss_reduction(
+    const at::Tensor& unreduced,
+    int64_t reduction) {
+  if (reduction == at::Reduction::Mean) {
+    return unreduced.mean();
+  } else if (reduction == at::Reduction::Sum) {
+    return unreduced.sum();
+  }
+  return unreduced;
+}
+
 Tensor& XPUNativeFunctions::mse_loss_out(
     const Tensor& input,
     const Tensor& target,
@@ -69,4 +79,53 @@ Tensor& XPUNativeFunctions::mse_loss_backward_out(
   return grad_input;
 }
 
+Tensor XPUNativeFunctions::huber_loss(
+    const Tensor& input,
+    const Tensor& target,
+    int64_t reduction,
+    double delta) {
+  TORCH_CHECK(
+      delta > 0, "huber_loss does not support non-positive values for delta.")
+  Tensor loss = at::empty_like(input);
+  auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
+  native::xpu::huber_kernel(iter, delta);
+  return apply_loss_reduction(loss, reduction);
+}
+
+Tensor& XPUNativeFunctions::huber_loss_out(
+    const Tensor& input,
+    const Tensor& target,
+    int64_t reduction,
+    double delta,
+    Tensor& result) {
+  TORCH_CHECK(
+      delta > 0, "huber_loss does not support non-positive values for delta.")
+  auto iter = TensorIterator::borrowing_binary_op(result, input, target);
+  native::xpu::huber_kernel(iter, delta);
+  if (reduction != Reduction::None) {
+    auto reduced = apply_loss_reduction(result, reduction);
+    result.resize_({});
+    result.copy_(reduced);
+  }
+  return result;
+}
+
+Tensor& XPUNativeFunctions::huber_loss_backward_out(
+    const Tensor& grad_output,
+    const Tensor& input,
+    const Tensor& target,
+    int64_t reduction,
+    double delta,
+    Tensor& grad_input) {
+  auto norm = (reduction == Reduction::Mean) ? (1. / input.numel()) : 1.;
+  auto iter = at::TensorIteratorConfig()
+                  .add_output(grad_input)
+                  .add_const_input(input)
+                  .add_const_input(target)
+                  .add_const_input(grad_output)
+                  .build();
+  native::xpu::huber_backward_kernel(iter, norm, delta);
+  return grad_input;
+}
+
 } // namespace at
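For orientation, a minimal host-side sketch of driving the new entry points through the public ATen API. This is a hypothetical snippet, not part of the commit: it assumes a libtorch build with this XPU backend available, and the shapes and delta value are arbitrary.

#include <ATen/ATen.h>
#include <ATen/core/Reduction.h>

int main() {
  // Arbitrary operands on the XPU device; any floating dtype dispatched
  // by the kernels (float, double, half, bfloat16) works.
  at::Tensor input = at::randn({8, 8}, at::device(at::kXPU));
  at::Tensor target = at::randn({8, 8}, at::device(at::kXPU));

  // Forward: the TORCH_CHECK above requires delta > 0.
  at::Tensor loss =
      at::huber_loss(input, target, at::Reduction::Mean, /*delta=*/1.0);

  // Out variant: per huber_loss_out above, with a Mean/Sum reduction the
  // result tensor is resized to a 0-dim tensor holding the reduced value.
  at::Tensor out = at::empty_like(input);
  at::huber_loss_out(out, input, target, at::Reduction::Mean, /*delta=*/1.0);
  return 0;
}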

src/ATen/native/xpu/XPUFallback.template

Lines changed: 0 additions & 2 deletions
@@ -210,8 +210,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "hardshrink.out",
       "heaviside.out",
       "histc",
-      "huber_loss",
-      "huber_loss_backward.out",
       "i0.out",
       "igammac.out",
       "igamma.out",

src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp

Lines changed: 21 additions & 0 deletions
@@ -23,4 +23,25 @@ void mse_kernel(TensorIteratorBase& iter) {
       [&]() { gpu_kernel(iter, MSEFunctor<scalar_t>()); });
 }
 
+template <typename scalar_t>
+struct HuberFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b) const {
+    auto z = std::abs(a - b);
+    return z < delta_val_ ? scalar_t(0.5) * z * z
+                          : delta_val_ * (z - scalar_t(0.5) * delta_val_);
+  }
+  HuberFunctor(scalar_t delta_val) : delta_val_(delta_val) {}
+
+ private:
+  scalar_t delta_val_;
+};
+
+void huber_kernel(TensorIterator& iter, double delta) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.dtype(), "huber_xpu", [&iter, delta] {
+        scalar_t delta_val(delta);
+        gpu_kernel(iter, HuberFunctor<scalar_t>(delta_val));
+      });
+}
+
 } // namespace at::native::xpu
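For reference, HuberFunctor evaluates the standard elementwise Huber loss:

$$
\ell_\delta(a,b) =
\begin{cases}
\tfrac{1}{2}z^{2}, & z < \delta,\\[2pt]
\delta\left(z - \tfrac{1}{2}\delta\right), & z \ge \delta,
\end{cases}
\qquad z = \lvert a - b \rvert.
$$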

src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h

Lines changed: 2 additions & 0 deletions
@@ -6,4 +6,6 @@ namespace at::native::xpu {
 
 void mse_kernel(TensorIteratorBase& iter);
 
+void huber_kernel(TensorIterator& iter, double delta);
+
 } // namespace at::native::xpu

src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp

Lines changed: 37 additions & 0 deletions
@@ -125,4 +125,41 @@ void mse_backward_kernel(TensorIterator& iter, const Scalar& value) {
   });
 }
 
+template <typename scalar_t>
+struct HuberBackwardFunctor {
+  scalar_t operator()(scalar_t input, scalar_t target, scalar_t grad_output)
+      const {
+    const auto x = input - target;
+    if (x < -delta_val_) {
+      return -norm_val_ * grad_output * delta_val_;
+    } else if (x > delta_val_) {
+      return norm_val_ * grad_output * delta_val_;
+    } else {
+      return norm_val_ * x * grad_output;
+    }
+  }
+  HuberBackwardFunctor(scalar_t norm_val, scalar_t delta_val)
+      : norm_val_(norm_val), delta_val_(delta_val) {}
+
+ private:
+  scalar_t norm_val_;
+  scalar_t delta_val_;
+};
+
+void huber_backward_kernel(
+    TensorIterator& iter,
+    const Scalar& norm,
+    double delta) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16,
+      kHalf,
+      iter.dtype(),
+      "huber_backward_xpu",
+      [&iter, &norm, delta] {
+        auto norm_val = norm.to<scalar_t>();
+        scalar_t delta_val(delta);
+        gpu_kernel(iter, HuberBackwardFunctor<scalar_t>(norm_val, delta_val));
+      });
+}
+
 } // namespace at::native::xpu
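Likewise, HuberBackwardFunctor applies the clamped derivative of that expression, scaled by the incoming gradient g and the reduction factor norm (1/numel for Mean, 1 otherwise, as set in huber_loss_backward_out):

$$
\frac{\partial \ell}{\partial\,\mathrm{input}}
= \mathrm{norm} \cdot g \cdot
\begin{cases}
-\delta, & x < -\delta,\\[2pt]
x, & -\delta \le x \le \delta,\\[2pt]
\delta, & x > \delta,
\end{cases}
\qquad x = \mathrm{input} - \mathrm{target}.
$$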

src/ATen/native/xpu/sycl/PointwiseOpsKernels.h

Lines changed: 5 additions & 0 deletions
@@ -10,4 +10,9 @@ void addcdiv_kernel(TensorIterator& iter, Scalar value);
 
 void mse_backward_kernel(TensorIterator& iter, const Scalar& value);
 
+void huber_backward_kernel(
+    TensorIterator& iter,
+    const Scalar& norm,
+    double delta);
+
 } // namespace at::native::xpu

test/xpu/extended/run_test_with_skip.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@
     "test_compare_cpu_nn_functional_batch_norm_xpu_bfloat16",
     "test_compare_cpu__batch_norm_with_update_xpu_bfloat16",
     "test_compare_cpu__batch_norm_with_update_xpu_float16",
+    "test_compare_cpu_nn_functional_huber_loss_xpu_bfloat16",
 
     # Not implemented operators, aten::upsample_linear1d, aten::upsample_bilinear2d,
     # aten::upsample_trilinear3d
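The new bfloat16 case is added to the compare-with-CPU skip list, presumably because bfloat16's low precision pushes the elementwise comparison past the default tolerance, as with the neighboring batch_norm skips.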

test/xpu/xpu_test_utils.py

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@
     "nn.functional.upsample_nearest",
     # "nn.functional.nll_loss", # Lack of XPU implementation of aten::nll_loss2d_forward. Will retrieve the case, only if the op is implemented.
     "nn.functional.mse_loss",
+    "nn.functional.huber_loss",
     "sigmoid",
     "sgn",
     "nn.functional.embedding_bag",

yaml/xpu_functions.yaml

Lines changed: 3 additions & 0 deletions
@@ -408,6 +408,9 @@ supported:
 - nll_loss_forward
 - nll_loss_backward.grad_input
 - nll_loss_backward
+- huber_loss
+- huber_loss.out
+- huber_loss_backward.out
 - batch_norm_stats
 - batch_norm_elemt
 - batch_norm_elemt.out
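These yaml entries declare the three variants as natively supported on XPU; this list presumably drives the codegen that emits the XPUNativeFunctions declarations implemented in Loss.cpp above.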
