Skip to content

Commit

Permalink
Add support for Trilu<bool>. (microsoft#20917)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Trilu<bool> is used by phi-3 when exported with torch.onnx.export.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
skottmckay authored Jun 6, 2024
1 parent eb2ec66 commit 3ecf48e
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 314 deletions.
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ Do not modify directly.*
|Transpose|*in* data:**T**<br> *out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int64)|
|Unique|*in* X:**T**<br> *out* Y:**T**<br> *out* indices:**tensor(int64)**<br> *out* inverse_indices:**tensor(int64)**<br> *out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)|
|Unsqueeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* expanded:**T**<br><br>or<br><br>*in* data:**T**<br> *out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/cpu/tensor/trilu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ONNX_OPERATOR_KERNEL_EX(
kOnnxDomain,
14,
kCpuExecutionProvider,
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints<float, double, int64_t>()),
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints<float, double, int64_t, bool>()),
Trilu);

template <typename T>
Expand Down Expand Up @@ -110,6 +110,9 @@ Status Trilu::Compute(OpKernelContext* ctx) const {
case sizeof(double):
status = TriluImpl<double>(X, Y, k_val, up);
break;
case sizeof(bool):
status = TriluImpl<bool>(X, Y, k_val, up);
break;
default:
ORT_THROW("Unsupported input data type of ", data_type);
}
Expand Down
Loading

0 comments on commit 3ecf48e

Please sign in to comment.