Current tests pass double-precision constants to where(), which works. No single-precision scalar type (such as using Float = Scalar<float>) is currently defined, but we'd like to extend where to support single-precision arguments. However, after adding that type, the following test fails to compile:
TEST_F(NVFuserTest, FusionWhereFloat32_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);
  auto tv0 = makeSymbolicTensor(1, DataType::Bool);
  fusion.addInput(tv0);
  using Float = Scalar<float>; // no built-in Float scalar, so we define it
  auto tv1 = where(tv0,
      IrBuilder::create<Float>(3.0),
      IrBuilder::create<Float>(5.0)
  );
  fusion.addOutput(tv1);
  auto options = at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
  auto t0 = at::randint(0, 1, {5}, options);
  auto ref = at::where(t0, (float)3.0, (float)5.0);
  std::vector<IValue> inputs = {t0};
  auto lparams = schedulePointwise(&fusion, inputs);
  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs, lparams);
  /*
  C++ exception with description "false INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/third_party/nvfuser/csrc/executor_utils.cpp":1237, please report a bug to PyTorch. ...
  __global__ void kernel1(Tensor<bool, 1> T0, Tensor<float, 1> T1) {
    int i59;
    i59 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
    if ((i59 < T0.size[0])) {
      bool T2[1];
      T2[0] = 0;
      T2[0] = T0[(((T0.stride[0] * ((nvfuser_index_t)blockIdx.x)) * 128) + (T0.stride[0] * ((nvfuser_index_t)threadIdx.x)))];
      float T3[1];
      T3[0] = where(T2[0], f3, f2);
      T1[i59] = T3[0];
    }
  }
  }
  CUDA NVRTC compile error: __tmp_kernel1.cu(8920): error: identifier "f3" is undefined
  __tmp_kernel1.cu(8920): error: identifier "f2" is undefined
  2 errors detected in the compilation of "__tmp_kernel1.cu".
  */
  auto cg_outputs = fe.runFusion(inputs);
  testValidate(&fusion, cg_outputs, inputs, {ref}, __LINE__, __FILE__);
}
The generated kernel references f3 and f2, but neither is declared as a kernel parameter, hence the NVRTC errors. If this is the only known use case for single-precision scalars, then it may be simpler to add a dtype argument to the C++ where. Otherwise, it seems we may need to extend the codegen to be aware of scalars other than Double so that they appear in the kernel signature.
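To make the "appear in the kernel signature" option concrete, here is a rough host-side sketch in plain C++. It is not nvfuser output and uses no nvfuser API: kernel1Host, trueVal and falseVal are hypothetical names, and the where template is only a stand-in for the device-side helper. It just illustrates the shape the kernel might take if the two float scalars were passed in as explicit parameters.

```cpp
#include <cstddef>
#include <vector>

// Stand-in for the device-side where() helper: pick a when cond is true.
template <typename T>
T where(bool cond, T a, T b) {
  return cond ? a : b;
}

// Hypothetical host-side analogue of kernel1. The two float scalars are
// explicit parameters here, whereas the failing kernel referenced f3 and f2
// without declaring them anywhere.
void kernel1Host(
    const std::vector<bool>& T0,
    float trueVal,
    float falseVal,
    std::vector<float>& T1) {
  T1.resize(T0.size());
  for (std::size_t i = 0; i < T0.size(); ++i) {
    T1[i] = where<float>(T0[i], trueVal, falseVal);
  }
}
```

With a mask of {true, false, true} and the scalars 3.0f and 5.0f, this fills the output with 3, 5, 3. Whether the real codegen should pass such scalars as parameters or handle them some other way is left open here.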
Versions
Collecting environment information...
PyTorch version: 2.0.0a0+git4121ffc
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A