-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add aten::lshift/rshift/prelu and thieir variants (#688)
- [x] prelu - [x] __lshift__ - [x] __ilshift__ - [x] __rshift__ - [x] __irshift__
- Loading branch information
1 parent
a818677
commit 00c9f3e
Showing
9 changed files
with
236 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/native/TensorIterator.h> | ||
|
||
#include <ATen/native/xpu/sycl/Loops.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
template <typename scalar_t> | ||
struct PreluFunctor { | ||
scalar_t operator()(scalar_t input, scalar_t weight) const { | ||
return (input > 0) ? input : weight * input; | ||
} | ||
}; | ||
|
||
template <typename scalar_t> | ||
struct PreluBackwardFunctor { | ||
std::tuple<scalar_t, scalar_t> operator()( | ||
scalar_t input, | ||
scalar_t weight, | ||
scalar_t grad) const { | ||
auto mask = input > 0; | ||
auto grad_input = mask ? grad : weight * grad; | ||
auto grad_weight = mask ? scalar_t{0} : input * grad; | ||
return std::tuple<scalar_t, scalar_t>{grad_input, grad_weight}; | ||
} | ||
}; | ||
|
||
void prelu_kernel(TensorIterator& iter) { | ||
AT_DISPATCH_FLOATING_TYPES_AND2( | ||
kBFloat16, kHalf, iter.dtype(), "prelu_xpu", [&] { | ||
gpu_kernel(iter, PreluFunctor<scalar_t>()); | ||
}); | ||
} | ||
|
||
void prelu_backward_kernel(TensorIterator& iter) { | ||
AT_DISPATCH_FLOATING_TYPES_AND2( | ||
kBFloat16, kHalf, iter.dtype(), "prelu_backward_xpu", [&] { | ||
gpu_kernel_multiple_outputs(iter, PreluBackwardFunctor<scalar_t>()); | ||
}); | ||
} | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#pragma once | ||
|
||
#include <ATen/native/TensorIterator.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
void prelu_kernel(TensorIterator& iter); | ||
|
||
void prelu_backward_kernel(TensorIterator& iter); | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/native/TensorIterator.h> | ||
|
||
#include <ATen/native/xpu/sycl/Loops.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
template <typename scalar_t> | ||
struct LshiftFunctor { | ||
scalar_t operator()(scalar_t a, scalar_t b) const { | ||
constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT; | ||
if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || | ||
(b >= max_shift)) { | ||
return 0; | ||
} | ||
return static_cast<std::make_unsigned_t<scalar_t>>(a) << b; | ||
} | ||
}; | ||
|
||
void lshift_kernel(TensorIteratorBase& iter) { | ||
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_xpu", [&]() { | ||
gpu_kernel_with_scalars(iter, LshiftFunctor<scalar_t>()); | ||
}); | ||
} | ||
|
||
template <typename scalar_t> | ||
struct RshiftFunctor { | ||
scalar_t operator()(scalar_t a, scalar_t b) const { | ||
// right shift value to retain sign bit for signed and no bits for | ||
// unsigned | ||
constexpr scalar_t max_shift = | ||
sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>; | ||
if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || | ||
(b >= max_shift)) { | ||
return a >> max_shift; | ||
} | ||
return a >> b; | ||
} | ||
}; | ||
|
||
void rshift_kernel(TensorIteratorBase& iter) { | ||
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_xpu", [&]() { | ||
gpu_kernel_with_scalars(iter, RshiftFunctor<scalar_t>()); | ||
}); | ||
} | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#pragma once | ||
|
||
#include <ATen/native/TensorIterator.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
void lshift_kernel(TensorIteratorBase& iter); | ||
|
||
void rshift_kernel(TensorIteratorBase& iter); | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters