Skip to content

[SPIRV] Cast derivative opts to 32-bits. #7445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 82 additions & 6 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9484,12 +9484,17 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
retVal = processIntrinsicPointerCast(callExpr, true);
break;
}
INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
INTRINSIC_SPIRV_OP_CASE(ddx_fine, DPdxFine, false);
INTRINSIC_SPIRV_OP_CASE(ddy, DPdy, true);
INTRINSIC_SPIRV_OP_CASE(ddy_coarse, DPdyCoarse, false);
INTRINSIC_SPIRV_OP_CASE(ddy_fine, DPdyFine, false);
case hlsl::IntrinsicOp::IOP_ddx:
case hlsl::IntrinsicOp::IOP_ddx_coarse:
case hlsl::IntrinsicOp::IOP_ddx_fine:
case hlsl::IntrinsicOp::IOP_ddy:
case hlsl::IntrinsicOp::IOP_ddy_coarse:
case hlsl::IntrinsicOp::IOP_ddy_fine: {
retVal = processDerivativeIntrinsic(hlslOpcode, callExpr->getArg(0),
callExpr->getExprLoc(),
callExpr->getSourceRange());
break;
}
INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
Expand Down Expand Up @@ -9572,6 +9577,77 @@ SpirvEmitter::processIntrinsicFirstbit(const CallExpr *callExpr,
srcRange);
}

SpirvInstruction *SpirvEmitter::processMatrixDerivativeIntrinsic(
hlsl::IntrinsicOp hlslOpcode, const Expr *arg, SourceLocation loc,
SourceRange range) {
const auto actOnEachVec = [this, hlslOpcode, loc, range](
uint32_t /*index*/, QualType inType,
QualType outType, SpirvInstruction *curRow) {
return processDerivativeIntrinsic(hlslOpcode, curRow, loc, range);
};

return processEachVectorInMatrix(arg, arg->getType(), doExpr(arg),
actOnEachVec, loc, range);
}

SpirvInstruction *
SpirvEmitter::processDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
const Expr *arg, SourceLocation loc,
SourceRange range) {
if (isMxNMatrix(arg->getType())) {
return processMatrixDerivativeIntrinsic(hlslOpcode, arg, loc, range);
}
return processDerivativeIntrinsic(hlslOpcode, doExpr(arg), loc, range);
}

SpirvInstruction *SpirvEmitter::processDerivativeIntrinsic(
hlsl::IntrinsicOp hlslOpcode, SpirvInstruction *arg, SourceLocation loc,
SourceRange range) {
QualType returnType = arg->getAstResultType();
assert(isFloatOrVecOfFloatType(returnType));

if (!spvContext.isPS())
addDerivativeGroupExecutionMode();
needsLegalization = true;

QualType B32Type = astContext.FloatTy;
uint32_t vectorSize = 0;
QualType elementType = returnType;
if (isVectorType(returnType, &elementType, &vectorSize)) {
B32Type = astContext.getExtVectorType(B32Type, vectorSize);
}

// Derivative operations work on 32-bit floats only. Cast to 32-bit if needed.
SpirvInstruction *operand = castToType(arg, returnType, B32Type, loc, range);

spv::Op opcode = spv::Op::OpNop;
switch (hlslOpcode) {
case hlsl::IntrinsicOp::IOP_ddx:
opcode = spv::Op::OpDPdx;
break;
case hlsl::IntrinsicOp::IOP_ddx_coarse:
opcode = spv::Op::OpDPdxCoarse;
break;
case hlsl::IntrinsicOp::IOP_ddx_fine:
opcode = spv::Op::OpDPdxFine;
break;
case hlsl::IntrinsicOp::IOP_ddy:
opcode = spv::Op::OpDPdy;
break;
case hlsl::IntrinsicOp::IOP_ddy_coarse:
opcode = spv::Op::OpDPdyCoarse;
break;
case hlsl::IntrinsicOp::IOP_ddy_fine:
opcode = spv::Op::OpDPdyFine;
break;
};

SpirvInstruction *result =
spvBuilder.createUnaryOp(opcode, B32Type, operand, loc, range);
result = castToType(result, B32Type, returnType, loc, range);
return result;
}

// Returns true is the given expression can be used as an output parameter.
//
// Warning: this function could return false negatives.
Expand Down
15 changes: 15 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,21 @@ class SpirvEmitter : public ASTConsumer {
SpirvInstruction *processIntrinsicFirstbit(const CallExpr *,
GLSLstd450 glslOpcode);

SpirvInstruction *
processMatrixDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
const Expr *arg, SourceLocation loc,
SourceRange range);

SpirvInstruction *processDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
const Expr *arg,
SourceLocation loc,
SourceRange range);

SpirvInstruction *processDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
SpirvInstruction *arg,
SourceLocation loc,
SourceRange range);

private:
/// Returns the <result-id> for constant value 0 of the given type.
SpirvConstant *getValueZero(QualType type);
Expand Down
21 changes: 21 additions & 0 deletions tools/clang/test/CodeGenSPIRV/intrinsics.ddx.double.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %dxc -T ps_6_2 -E main -fcgl %s -spirv 2>&1 | FileCheck %s

// CHECK: :14:22: warning: conversion from larger type 'double' to smaller type 'float', possible loss of data [-Wconversion]
// CHECK: :20:22: warning: conversion from larger type 'double2' to smaller type 'vector<float, 2>', possible loss of data [-Wconversion]

void main() {
double a;
double2 b;

// CHECK: [[a:%[0-9]+]] = OpLoad %double %a
// CHECK-NEXT: [[c:%[0-9]+]] = OpFConvert %float [[a]]
// CHECK-NEXT: [[r:%[0-9]+]] = OpDPdx %float [[c]]
// CHECK-NEXT: OpFConvert %double [[r]]
double da = ddx(a);

// CHECK: [[b:%[0-9]+]] = OpLoad %v2double %b
// CHECK-NEXT: [[c:%[0-9]+]] = OpFConvert %v2float [[b]]
// CHECK-NEXT: [[r:%[0-9]+]] = OpDPdx %v2float [[c]]
// CHECK-NEXT: OpFConvert %v2double [[r]]
double2 db = ddx(b);
}
19 changes: 19 additions & 0 deletions tools/clang/test/CodeGenSPIRV/intrinsics.ddx.half.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: %dxc -T ps_6_2 -E main -enable-16bit-types -fcgl %s -spirv | FileCheck %s

void main() {

half a;
half2 b;

// CHECK: [[a:%[0-9]+]] = OpLoad %half %a
// CHECK-NEXT: [[c:%[0-9]+]] = OpFConvert %float [[a]]
// CHECK-NEXT: [[r:%[0-9]+]] = OpDPdx %float [[c]]
// CHECK-NEXT: OpFConvert %half [[r]]
half da = ddx(a);

// CHECK: [[b:%[0-9]+]] = OpLoad %v2half %b
// CHECK-NEXT: [[c:%[0-9]+]] = OpFConvert %v2float [[b]]
// CHECK-NEXT: [[r:%[0-9]+]] = OpDPdx %v2float [[c]]
// CHECK-NEXT: OpFConvert %v2half [[r]]
half2 db = ddx(b);
}