Skip to content

Commit

Permalink
Allow native vectors for LLVM operations
Browse files Browse the repository at this point in the history
Disables various forms of scalarization and vector elimination to permit
vectors to pass through to final DXIL when used in native LLVM
operations and loading/storing.

Introduces a few vector manipulation llvm instructions to DXIL allowing
for them to appear in output DXIL.

Skips passes for 6.9 that scalarize, convert to arrays, or otherwise eliminate vectors.
This eliminates the element-by-element loading of the vectors
In many cases, this required plumbing the shader model information to
passes that didn't have it before.

Many changes were needed for the MatrixBitcastLower pass related to
linking to avoid converting matrix vectors, but also to perform the
conversion if a shader was compiled for 6.9+, but then linked to a
earlier target.
This now adapts to the linker target to either preserve vectors for 6.9 or arrays for previous versions.
This requires running the DynamicIndexing VectorToArray pass during linking since 6_x and 6_9+ will fail to run this in the initial compile, but will still need to lower vectors to arrays.

Ternary conditional/select operators were element extracted in codegen.
Removing this allows 6.9 to preserve the vectors, but also maintains
behavior for previous shader models because the operations get
scalarized later anyway.

Keep groupshared variables as vectors for 6.9. They are no longer represented as indivual groupshared scalars.

Adds extensive tests for these operations using different types and
sizes and testing them appropriately. Booleans produce significantly
different code, so they get their own test.

Fixes microsoft#7123
  • Loading branch information
pow2clk committed Feb 20, 2025
1 parent e010223 commit 68a284b
Show file tree
Hide file tree
Showing 15 changed files with 1,165 additions and 42 deletions.
36 changes: 36 additions & 0 deletions include/dxc/DXIL/DxilInstructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,42 @@ struct LlvmInst_VAArg {
bool isAllowed() const { return false; }
};

/// This instruction extracts from vector
struct LlvmInst_ExtractElement {
llvm::Instruction *Instr;
// Construction and identification
LlvmInst_ExtractElement(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return Instr->getOpcode() == llvm::Instruction::ExtractElement;
}
// Validation support
bool isAllowed() const { return true; }
};

/// This instruction inserts into vector
struct LlvmInst_InsertElement {
llvm::Instruction *Instr;
// Construction and identification
LlvmInst_InsertElement(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return Instr->getOpcode() == llvm::Instruction::InsertElement;
}
// Validation support
bool isAllowed() const { return true; }
};

/// This instruction Shuffle two vectors
struct LlvmInst_ShuffleVector {
llvm::Instruction *Instr;
// Construction and identification
LlvmInst_ShuffleVector(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return Instr->getOpcode() == llvm::Instruction::ShuffleVector;
}
// Validation support
bool isAllowed() const { return true; }
};

/// This instruction extracts from aggregate
struct LlvmInst_ExtractValue {
llvm::Instruction *Instr;
Expand Down
2 changes: 2 additions & 0 deletions lib/DxilValidation/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2158,6 +2158,8 @@ static bool ValidateType(Type *Ty, ValidationContext &ValCtx,
return true;

if (Ty->isVectorTy()) {
if (ValCtx.DxilMod.GetShaderModel()->IsSM69Plus())
return true;
ValCtx.EmitTypeError(Ty, ValidationRule::TypesNoVector);
return false;
}
Expand Down
6 changes: 6 additions & 0 deletions lib/HLSL/DxilLinker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,12 @@ void DxilLinkJob::RunPreparePass(Module &M) {
// For static global handle.
PM.add(createLowerStaticGlobalIntoAlloca());

// Change dynamic indexing vector to array where vectors aren't
// supported, but might be there from the initial compile.
if (!pSM->IsSM69Plus())
PM.add(
createDynamicIndexingVectorToArrayPass(false /* ReplaceAllVector */));

// Remove MultiDimArray from function call arg.
PM.add(createMultiDimArrayToOneDimArrayPass());

Expand Down
44 changes: 28 additions & 16 deletions lib/HLSL/HLMatrixBitcastLowerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ class MatrixBitcastLowerPass : public FunctionPass {

// Lower matrix first.
for (BitCastInst *BCI : matCastSet) {
lowerMatrix(BCI, BCI->getOperand(0));
lowerMatrix(DM, BCI, BCI->getOperand(0));
}
return bUpdated;
}

private:
void lowerMatrix(Instruction *M, Value *A);
void lowerMatrix(DxilModule &DM, Instruction *M, Value *A);
bool hasCallUser(Instruction *M);
};

Expand Down Expand Up @@ -180,7 +180,8 @@ Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
}
} // namespace

void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
void MatrixBitcastLowerPass::lowerMatrix(DxilModule &DM, Instruction *M,
Value *A) {
for (auto it = M->user_begin(); it != M->user_end();) {
User *U = *(it++);
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
Expand All @@ -193,31 +194,42 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
DXASSERT(idxList.size() == 2,
"else not one dim matrix array index to matrix");

HLMatrixType MatTy = HLMatrixType::cast(EltTy);
Value *matSize = Builder.getInt32(MatTy.getNumElements());
idxList.back() = Builder.CreateMul(idxList.back(), matSize);
if (!DM.GetShaderModel()->IsSM69Plus()) {
HLMatrixType MatTy = HLMatrixType::cast(EltTy);
Value *matSize = Builder.getInt32(MatTy.getNumElements());
idxList.back() = Builder.CreateMul(idxList.back(), matSize);
}
Value *NewGEP = Builder.CreateGEP(A, idxList);
lowerMatrix(GEP, NewGEP);
lowerMatrix(DM, GEP, NewGEP);
DXASSERT(GEP->user_empty(), "else lower matrix fail");
GEP->eraseFromParent();
} else {
DXASSERT(0, "invalid GEP for matrix");
}
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
lowerMatrix(BCI, A);
lowerMatrix(DM, BCI, A);
DXASSERT(BCI->user_empty(), "else lower matrix fail");
BCI->eraseFromParent();
} else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
IRBuilder<> Builder(LI);
Value *zeroIdx = Builder.getInt32(0);
unsigned vecSize = Ty->getNumElements();
Value *NewVec = UndefValue::get(LI->getType());
for (unsigned i = 0; i < vecSize; i++) {
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
Value *Elt = Builder.CreateLoad(GEP);
NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
Value *NewVec = nullptr;
if (DM.GetShaderModel()->IsSM69Plus()) {
// Just create a replacement load using the vector pointer.
Instruction *NewLI = LI->clone();
unsigned VecIdx = NewLI->getNumOperands() - 1;
NewLI->setOperand(VecIdx, A);
Builder.Insert(NewLI);
NewVec = NewLI;
} else {
Value *zeroIdx = Builder.getInt32(0);
unsigned vecSize = Ty->getNumElements();
NewVec = UndefValue::get(LI->getType());
for (unsigned i = 0; i < vecSize; i++) {
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
Value *Elt = Builder.CreateLoad(GEP);
NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
}
}
LI->replaceAllUsesWith(NewVec);
LI->eraseFromParent();
Expand Down
6 changes: 6 additions & 0 deletions lib/Transforms/Scalar/DxilEliminateVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
// //
///////////////////////////////////////////////////////////////////////////////

#include "dxc/DXIL/DxilModule.h"

#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
Expand Down Expand Up @@ -151,6 +153,10 @@ bool DxilEliminateVector::TryRewriteDebugInfoForVector(InsertElementInst *IE) {

bool DxilEliminateVector::runOnFunction(Function &F) {

if (F.getParent()->HasDxilModule())
if (F.getParent()->GetDxilModule().GetShaderModel()->IsSM69Plus())
return false;

auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
DxilValueCache *DVC = &getAnalysis<DxilValueCache>();

Expand Down
18 changes: 15 additions & 3 deletions lib/Transforms/Scalar/LowerTypePasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "dxc/DXIL/DxilConstants.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/HLSL/HLModule.h"
Expand Down Expand Up @@ -180,10 +181,12 @@ bool LowerTypePass::runOnModule(Module &M) {
namespace {
class DynamicIndexingVectorToArray : public LowerTypePass {
bool ReplaceAllVectors;
bool SupportsVectors;

public:
explicit DynamicIndexingVectorToArray(bool ReplaceAll = false)
: LowerTypePass(ID), ReplaceAllVectors(ReplaceAll) {}
: LowerTypePass(ID), ReplaceAllVectors(ReplaceAll),
SupportsVectors(false) {}
static char ID; // Pass identification, replacement for typeid
void applyOptions(PassOptions O) override;
void dumpConfig(raw_ostream &OS) override;
Expand All @@ -194,6 +197,7 @@ class DynamicIndexingVectorToArray : public LowerTypePass {
Type *lowerType(Type *Ty) override;
Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
StringRef getGlobalPrefix() override { return ".v"; }
void initialize(Module &M) override;

private:
bool HasVectorDynamicIndexing(Value *V);
Expand All @@ -207,6 +211,11 @@ class DynamicIndexingVectorToArray : public LowerTypePass {
void ReplaceAddrSpaceCast(ConstantExpr *CE, Value *A, IRBuilder<> &Builder);
};

void DynamicIndexingVectorToArray::initialize(Module &M) {
if (M.HasHLModule())
SupportsVectors = M.GetHLModule().GetShaderModel()->IsSM69Plus();
}

void DynamicIndexingVectorToArray::applyOptions(PassOptions O) {
GetPassOptionBool(O, "ReplaceAllVectors", &ReplaceAllVectors,
ReplaceAllVectors);
Expand Down Expand Up @@ -286,7 +295,7 @@ void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
StoreInst *stInst = cast<StoreInst>(GEPUser);
Value *val = stInst->getValueOperand();
Value *ldVal = Builder.CreateLoad(V);
ldVal = Builder.CreateInsertElement(ldVal, val, constIdx);
ldVal = Builder.CreateInsertElement(ldVal, val, constIdx); // UGH
Builder.CreateStore(ldVal, V);
stInst->eraseFromParent();
}
Expand All @@ -306,8 +315,11 @@ void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
}

bool DynamicIndexingVectorToArray::needToLower(Value *V) {
// Only needed where vectors aren't supported.
if (SupportsVectors)
return false;
Type *Ty = V->getType()->getPointerElementType();
if (dyn_cast<VectorType>(Ty)) {
if (isa<VectorType>(Ty)) {
if (isa<GlobalVariable>(V) || ReplaceAllVectors) {
return true;
}
Expand Down
18 changes: 11 additions & 7 deletions lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1869,7 +1869,8 @@ bool SROAGlobalAndAllocas(HLModule &HLM, bool bHasDbgInfo) {
// if
// all its users can be transformed, then split up the aggregate into its
// separate elements.
if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) {
if (!HLM.GetShaderModel()->IsSM69Plus() && ShouldAttemptScalarRepl(AI) &&
isSafeAllocaToScalarRepl(AI)) {
std::vector<Value *> Elts;
IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(AI));
bool hasPrecise = HLModule::HasPreciseAttributeWithMetadata(AI);
Expand Down Expand Up @@ -1945,8 +1946,9 @@ bool SROAGlobalAndAllocas(HLModule &HLM, bool bHasDbgInfo) {
continue;
}

// Flat Global vector if no dynamic vector indexing.
bool bFlatVector = !hasDynamicVectorIndexing(GV);
// Flat Global vector if no dynamic vector indexing and pre-6.9.
bool bFlatVector =
!hasDynamicVectorIndexing(GV) && !HLM.GetShaderModel()->IsSM69Plus();

if (bFlatVector) {
GVDbgOffset &dbgOffset = GVDbgOffsetMap[GV];
Expand Down Expand Up @@ -1980,10 +1982,12 @@ bool SROAGlobalAndAllocas(HLModule &HLM, bool bHasDbgInfo) {
} else {
// SROA_Parameter_HLSL has no access to a domtree, if one is needed,
// it'll be generated
SROAed = SROA_Helper::DoScalarReplacement(
GV, Elts, Builder, bFlatVector,
// TODO: set precise.
/*hasPrecise*/ false, typeSys, DL, DeadInsts, /*DT*/ nullptr);
if (!HLM.GetShaderModel()->IsSM69Plus()) {
SROAed = SROA_Helper::DoScalarReplacement(
GV, Elts, Builder, bFlatVector,
// TODO: set precise.
/*hasPrecise*/ false, typeSys, DL, DeadInsts, /*DT*/ nullptr);
}
}

if (SROAed) {
Expand Down
6 changes: 6 additions & 0 deletions lib/Transforms/Scalar/Scalarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
//
//===----------------------------------------------------------------------===//

#include "dxc/DXIL/DxilModule.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
Expand Down Expand Up @@ -290,6 +292,10 @@ bool Scalarizer::doInitialization(Module &M) {
}

bool Scalarizer::runOnFunction(Function &F) {
if (F.getParent()->HasDxilModule())
if (F.getParent()->GetDxilModule().GetShaderModel()->IsSM69Plus())
return false;

for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
BasicBlock *BB = BBI;
for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
Expand Down
15 changes: 1 addition & 14 deletions tools/clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3713,20 +3713,7 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
llvm::Value *CondV = CGF.EmitScalarExpr(condExpr);
llvm::Value *LHS = Visit(lhsExpr);
llvm::Value *RHS = Visit(rhsExpr);
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(CondV->getType())) {
llvm::VectorType *ResultVT = cast<llvm::VectorType>(LHS->getType());
llvm::Value *result = llvm::UndefValue::get(ResultVT);
for (unsigned i = 0; i < VT->getNumElements(); i++) {
llvm::Value *EltCond = Builder.CreateExtractElement(CondV, i);
llvm::Value *EltL = Builder.CreateExtractElement(LHS, i);
llvm::Value *EltR = Builder.CreateExtractElement(RHS, i);
llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
result = Builder.CreateInsertElement(result, EltSelect, i);
}
return result;
} else {
return Builder.CreateSelect(CondV, LHS, RHS);
}
return Builder.CreateSelect(CondV, LHS, RHS);
}
if (hlsl::IsHLSLMatType(E->getType())) {
llvm::Value *Cond = CGF.EmitScalarExpr(condExpr);
Expand Down
8 changes: 6 additions & 2 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6487,6 +6487,9 @@ bool HLSLExternalSource::MatchArguments(
}
}

std::string profile = m_sema->getLangOpts().HLSLProfile;
const ShaderModel *SM = hlsl::ShaderModel::GetByName(profile.c_str());

// Populate argTypes.
for (size_t i = 0; i <= Args.size(); i++) {
const HLSL_INTRINSIC_ARGUMENT *pArgument = &pIntrinsic->pArgs[i];
Expand Down Expand Up @@ -6657,8 +6660,9 @@ bool HLSLExternalSource::MatchArguments(
}

// Verify that the final results are in bounds.
CAB(uCols > 0 && uCols <= MaxVectorSize && uRows > 0 &&
uRows <= MaxVectorSize,
CAB((uCols > 0 && uRows > 0 &&
((uCols <= MaxVectorSize && uRows <= MaxVectorSize) ||
(SM->IsSM69Plus() && uRows == 1))),
i);

// Const
Expand Down
Loading

0 comments on commit 68a284b

Please sign in to comment.