Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Type inference must realize that a <1 x Type> vector type is not a legal vector type in LLT #124560

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

VyacheslavLevytskyy
Copy link
Contributor

In this PR we account for possible <1 x LLVM Type> input to ensure that we produce legal vector types during type inference.

We modify an LLVM type to conform with future transformations in IRTranslator, if it's a <1 x Type> vector type, replacing it by the element type, because <1 x Type> vector type is not a legal vector type in LLT and IRTranslator will represent it as the scalar eventually.

Copy link

github-actions bot commented Jan 31, 2025

✅ With the latest revision this PR passed the undef deprecator.

@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

In this PR we account for possible <1 x LLVM Type> input to ensure that we produce legal vector types during type inference.

We modify an LLVM type to conform with future transformations in IRTranslator, if it's a <1 x Type> vector type, replacing it by the element type, because <1 x Type> vector type is not a legal vector type in LLT and IRTranslator will represent it as the scalar eventually.


Patch is 20.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124560.diff

3 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+33-19)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+22)
  • (added) llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll (+221)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 702206b8e0dc56..52614d378c465a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -427,7 +427,7 @@ Type *SPIRVEmitIntrinsics::reconstructType(Value *Op, bool UnknownElemTypeI8,
 
 void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
                                           Value *Arg) {
-  Value *OfType = PoisonValue::get(Ty);
+  Value *OfType = getNormalizedPoisonValue(Ty);
   CallInst *AssignCI = nullptr;
   if (Arg->getType()->isAggregateType() && Ty->isAggregateType() &&
       allowEmitFakeUse(Arg)) {
@@ -447,6 +447,7 @@ void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
 
 void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
                                          Value *Arg) {
+  ElemTy = normalizeType(ElemTy);
   Value *OfType = PoisonValue::get(ElemTy);
   CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
   if (AssignPtrTyCI == nullptr ||
@@ -470,7 +471,7 @@ void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
     return;
 
   // update association with the pointee type
-  Type *ElemTy = OfType->getType();
+  Type *ElemTy = normalizeType(OfType->getType());
   GR->addDeducedElementType(AssignCI, ElemTy);
   GR->addDeducedElementType(Arg, ElemTy);
 }
@@ -490,7 +491,7 @@ CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Function *F, Value *Op,
   }
   Type *OpTy = Op->getType();
   SmallVector<Type *, 2> Types = {OpTy, OpTy};
-  SmallVector<Value *, 2> Args = {Op, buildMD(PoisonValue::get(ElemTy)),
+  SmallVector<Value *, 2> Args = {Op, buildMD(getNormalizedPoisonValue(ElemTy)),
                                   B.getInt32(getPointerAddressSpace(OpTy))};
   CallInst *PtrCasted =
       B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -766,7 +767,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
   // remember the found relationship
   if (Ty && !IgnoreKnownType) {
     // specify nested types if needed, otherwise return unchanged
-    GR->addDeducedElementType(I, Ty);
+    GR->addDeducedElementType(I, normalizeType(Ty));
   }
 
   return Ty;
@@ -852,7 +853,7 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
       }
       if (Ty != OpTy) {
         Type *NewTy = VectorType::get(Ty, VecTy->getElementCount());
-        GR->addDeducedCompositeType(U, NewTy);
+        GR->addDeducedCompositeType(U, normalizeType(NewTy));
         return NewTy;
       }
     }
@@ -990,6 +991,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
   if (KnownElemTy)
     return false;
   if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+    OpElemTy = normalizeType(OpElemTy);
     GR->addDeducedElementType(F, OpElemTy);
     GR->addReturnType(
         F, TypedPointerType::get(OpElemTy,
@@ -1002,7 +1004,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
         continue;
       if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
         if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
-          updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
+          updateAssignType(AssignCI, CI, getNormalizedPoisonValue(OpElemTy));
           propagateElemType(CI, PrevElemTy, VisitedSubst);
         }
       }
@@ -1162,11 +1164,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
     Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op);
     if (Ty == KnownElemTy)
       continue;
-    Value *OpTyVal = PoisonValue::get(KnownElemTy);
+    Value *OpTyVal = getNormalizedPoisonValue(KnownElemTy);
     Type *OpTy = Op->getType();
     if (!Ty || AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)) {
       Type *PrevElemTy = GR->findDeducedElementType(Op);
-      GR->addDeducedElementType(Op, KnownElemTy);
+      GR->addDeducedElementType(Op, normalizeType(KnownElemTy));
       // check if KnownElemTy is complete
       if (!Uncomplete)
         eraseTodoType(Op);
@@ -1492,7 +1494,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt(
 
   // Our previous guess about the type seems to be wrong, let's update
   // inferred type according to a new, more precise type information.
-  updateAssignType(AssignCI, V, PoisonValue::get(AssignedType));
+  updateAssignType(AssignCI, V, getNormalizedPoisonValue(AssignedType));
 }
 
 void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
@@ -1507,7 +1509,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
     return;
 
   setInsertPointSkippingPhis(B, I);
-  Value *ExpectedElementVal = PoisonValue::get(ExpectedElementType);
+  Value *ExpectedElementVal = getNormalizedPoisonValue(ExpectedElementType);
   MetadataAsValue *VMD = buildMD(ExpectedElementVal);
   unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
   bool FirstPtrCastOrAssignPtrType = true;
@@ -1653,7 +1655,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
       if (!ElemTy) {
         ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
         if (ElemTy) {
-          GR->addDeducedElementType(CalledArg, ElemTy);
+          GR->addDeducedElementType(CalledArg, normalizeType(ElemTy));
         } else {
           for (User *U : CalledArg->users()) {
             if (Instruction *Inst = dyn_cast<Instruction>(U)) {
@@ -1704,6 +1706,11 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
 }
 
 Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
+  // If it's a <1 x Type> vector type, don't modify it. It's not a legal vector
+  // type in LLT and IRTranslator will replace it by the scalar.
+  if (isVector1(I.getType()))
+    return &I;
+
   SmallVector<Type *, 4> Types = {I.getType(), I.getOperand(0)->getType(),
                                   I.getOperand(1)->getType(),
                                   I.getOperand(2)->getType()};
@@ -1717,6 +1724,11 @@ Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
 
 Instruction *
 SPIRVEmitIntrinsics::visitExtractElementInst(ExtractElementInst &I) {
+  // If it's a <1 x Type> vector type, don't modify it. It's not a legal vector
+  // type in LLT and IRTranslator will replace it by the scalar.
+  if (isVector1(I.getVectorOperandType()))
+    return &I;
+
   IRBuilder<> B(I.getParent());
   B.SetInsertPoint(&I);
   SmallVector<Type *, 3> Types = {I.getType(), I.getVectorOperandType(),
@@ -1984,8 +1996,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
           Type *ElemTy = GR->findDeducedElementType(Op);
           buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true), Op);
         } else {
-          CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
-                                               {OpTy}, Op, Op, {}, B);
+          CallInst *AssignCI =
+              buildIntrWithMD(Intrinsic::spv_assign_type, {OpTy},
+                              getNormalizedPoisonValue(OpTy), Op, {}, B);
           GR->addAssignPtrTypeInstr(Op, AssignCI);
         }
       }
@@ -2034,7 +2047,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
       Type *OpTy = Op->getType();
       Value *OpTyVal = Op;
       if (OpTy->isTargetExtTy())
-        OpTyVal = PoisonValue::get(OpTy);
+        OpTyVal = getNormalizedPoisonValue(OpTy);
       CallInst *NewOp =
           buildIntrWithMD(Intrinsic::spv_track_constant,
                           {OpTy, OpTyVal->getType()}, Op, OpTyVal, {}, B);
@@ -2045,7 +2058,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
         buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp);
         SmallVector<Type *, 2> Types = {OpTy, OpTy};
         SmallVector<Value *, 2> Args = {
-            NewOp, buildMD(PoisonValue::get(OpElemTy)),
+            NewOp, buildMD(getNormalizedPoisonValue(OpElemTy)),
             B.getInt32(getPointerAddressSpace(OpTy))};
         CallInst *PtrCasted =
             B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -2178,7 +2191,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
     if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
       if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(Arg)) {
         DenseSet<std::pair<Value *, Value *>> VisitedSubst;
-        updateAssignType(AssignCI, Arg, PoisonValue::get(ElemTy));
+        updateAssignType(AssignCI, Arg, getNormalizedPoisonValue(ElemTy));
         propagateElemType(Arg, IntegerType::getInt8Ty(F->getContext()),
                           VisitedSubst);
       } else {
@@ -2232,7 +2245,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
           continue;
         if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type ||
             II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
-          updateAssignType(II, &F, PoisonValue::get(FPElemTy));
+          updateAssignType(II, &F, getNormalizedPoisonValue(FPElemTy));
           break;
         }
       }
@@ -2256,7 +2269,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
   for (Function *F : Worklist) {
     SmallVector<Value *> Args;
     for (const auto &Arg : F->args())
-      Args.push_back(PoisonValue::get(Arg.getType()));
+      Args.push_back(getNormalizedPoisonValue(Arg.getType()));
     IRB.CreateCall(F, Args);
   }
   IRB.CreateRetVoid();
@@ -2286,7 +2299,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
             buildAssignPtr(B, ElemTy, Arg);
           }
         } else if (isa<Instruction>(Param)) {
-          GR->addDeducedElementType(Param, ElemTy);
+          GR->addDeducedElementType(Param, normalizeType(ElemTy));
           // insertAssignTypeIntrs() will complete buildAssignPtr()
         } else {
           B.SetInsertPoint(CI->getParent()
@@ -2302,6 +2315,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
         if (!RefF || !isPointerTy(RefF->getReturnType()) ||
             GR->findDeducedElementType(RefF))
           continue;
+        ElemTy = normalizeType(ElemTy);
         GR->addDeducedElementType(RefF, ElemTy);
         GR->addReturnType(
             RefF, TypedPointerType::get(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index fd48098257065a..552adf2df7d179 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -383,6 +383,28 @@ inline const Type *unifyPtrType(const Type *Ty) {
   return toTypedPointer(const_cast<Type *>(Ty));
 }
 
+inline bool isVector1(Type *Ty) {
+  auto *FVTy = dyn_cast<FixedVectorType>(Ty);
+  return FVTy && FVTy->getNumElements() == 1;
+}
+
+// Modify an LLVM type to conform with future transformations in IRTranslator.
+// At the moment use cases comprise only a <1 x Type> vector. To extend when/if
+// needed.
+inline Type *normalizeType(Type *Ty) {
+  auto *FVTy = dyn_cast<FixedVectorType>(Ty);
+  if (!FVTy || FVTy->getNumElements() != 1)
+    return Ty;
+  // If it's a <1 x Type> vector type, replace it by the element type, because
+  // it's not a legal vector type in LLT and IRTranslator will represent it as
+  // the scalar eventually.
+  return normalizeType(FVTy->getElementType());
+}
+
+inline PoisonValue *getNormalizedPoisonValue(Type *Ty) {
+  return PoisonValue::get(normalizeType(Ty));
+}
+
 MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
 
 #define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"
diff --git a/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll b/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
new file mode 100644
index 00000000000000..d8a6c85b3d4073
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
@@ -0,0 +1,221 @@
+; This is an excerpt from the tutorial of the Triton language converted into
+; LLVM IR via the Triton XPU backend and cleaned of irrelevant details.
+; The only pass criterion is that spirv-val considers output valid.
+
+; Ths particular case is related to translation of <1 x Ty> vectors.
+
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.4 %}
+
+define spir_kernel void @softmax_kernel(ptr addrspace(1) nocapture writeonly %0, ptr addrspace(1) nocapture readonly %1, i32 %2, i32 %3, i32 %4, i32 %5, ptr addrspace(3) nocapture %6) {
+  %8 = tail call spir_func i64 @_Z12get_group_idj(i32 0)
+  %9 = trunc i64 %8 to i32
+  %10 = tail call spir_func i64 @_Z14get_num_groupsj(i32 0)
+  %11 = trunc i64 %10 to i32
+  %12 = tail call spir_func i64 @_Z12get_local_idj(i32 0)
+  %13 = trunc i64 %12 to i32
+  %14 = and i32 %13, 255
+  %15 = or disjoint i32 %14, 256
+  %16 = or disjoint i32 %14, 512
+  %17 = or disjoint i32 %14, 768
+  %18 = icmp slt i32 %14, %5
+  %19 = icmp slt i32 %15, %5
+  %20 = icmp slt i32 %16, %5
+  %21 = icmp slt i32 %17, %5
+  %22 = icmp sgt i32 %4, %9
+  br i1 %22, label %.lr.ph, label %._crit_edge
+
+.lr.ph:                                           ; preds = %7
+  %23 = lshr i64 %12, 5
+  %24 = and i32 %13, 31
+  %25 = zext nneg i32 %15 to i64
+  %26 = zext nneg i32 %16 to i64
+  %27 = zext nneg i32 %17 to i64
+  %28 = and i64 %12, 255
+  %29 = and i64 %23, 7
+  %30 = icmp eq i32 %24, 0
+  %31 = getelementptr float, ptr addrspace(3) %6, i64 %29
+  %32 = icmp slt i32 %13, 8
+  %sext = shl i64 %12, 32
+  %33 = ashr exact i64 %sext, 30
+  %34 = getelementptr i8, ptr addrspace(3) %6, i64 %33
+  %35 = and i32 %13, 7
+  %36 = icmp eq i32 %35, 0
+  %37 = and i1 %32, %36
+  br label %38
+
+38:                                               ; preds = %.lr.ph, %123
+  %39 = phi i32 [ %9, %.lr.ph ], [ %124, %123 ]
+  %40 = mul i32 %39, %2
+  %41 = sext i32 %40 to i64
+  %42 = getelementptr float, ptr addrspace(1) %1, i64 %41
+  %43 = getelementptr float, ptr addrspace(1) %42, i64 %25
+  %44 = getelementptr float, ptr addrspace(1) %42, i64 %26
+  %45 = getelementptr float, ptr addrspace(1) %42, i64 %27
+  br i1 %18, label %46, label %49
+
+46:                                               ; preds = %38
+  %47 = getelementptr float, ptr addrspace(1) %42, i64 %28
+  %48 = load <1 x float>, ptr addrspace(1) %47, align 4
+  br label %49
+
+49:                                               ; preds = %46, %38
+  %50 = phi <1 x float> [ %48, %46 ], [ splat (float 0xFFF0000000000000), %38 ]
+  %51 = extractelement <1 x float> %50, i64 0
+  br i1 %19, label %52, label %54
+
+52:                                               ; preds = %49
+  %53 = load <1 x float>, ptr addrspace(1) %43, align 4
+  br label %54
+
+54:                                               ; preds = %52, %49
+  %55 = phi <1 x float> [ %53, %52 ], [ splat (float 0xFFF0000000000000), %49 ]
+  %56 = extractelement <1 x float> %55, i64 0
+  br i1 %20, label %57, label %59
+
+57:                                               ; preds = %54
+  %58 = load <1 x float>, ptr addrspace(1) %44, align 4
+  br label %59
+
+59:                                               ; preds = %57, %54
+  %60 = phi <1 x float> [ %58, %57 ], [ splat (float 0xFFF0000000000000), %54 ]
+  %61 = extractelement <1 x float> %60, i64 0
+  br i1 %21, label %62, label %64
+
+62:                                               ; preds = %59
+  %63 = load <1 x float>, ptr addrspace(1) %45, align 4
+  br label %64
+
+64:                                               ; preds = %62, %59
+  %65 = phi <1 x float> [ %63, %62 ], [ splat (float 0xFFF0000000000000), %59 ]
+  %66 = extractelement <1 x float> %65, i64 0
+  tail call spir_func void @_Z7barrierj(i32 1)
+  %67 = tail call float @llvm.maxnum.f32(float %51, float %56)
+  %68 = tail call float @llvm.maxnum.f32(float %67, float %61)
+  %69 = tail call float @llvm.maxnum.f32(float %68, float %66)
+  %70 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32 3, i32 0, float %69)
+  br i1 %30, label %71, label %72
+
+71:                                               ; preds = %64
+  store float %70, ptr addrspace(3) %31, align 4
+  br label %72
+
+72:                                               ; preds = %71, %64
+  tail call spir_func void @_Z7barrierj(i32 1)
+  br i1 %32, label %74, label %.thread1
+
+.thread1:                                         ; preds = %72
+  %73 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float poison, i32 8)
+  br label %78
+
+74:                                               ; preds = %72
+  %75 = load float, ptr addrspace(3) %34, align 4
+  %76 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float %75, i32 8)
+  br i1 %37, label %77, label %78
+
+77:                                               ; preds = %74
+  store float %76, ptr addrspace(3) %34, align 4
+  br label %78
+
+78:                                               ; preds = %.thread1, %77, %74
+  tail call spir_func void @_Z7barrierj(i32 1)
+  %79 = load float, ptr addrspace(3) %6, align 4
+  %80 = fsub float %51, %79
+  %81 = fsub float %56, %79
+  %82 = fsub float %61, %79
+  %83 = fsub float %66, %79
+  %84 = fmul float %80, 0x3FF7154760000000
+  %85 = tail call float @llvm.exp2.f32(float %84)
+  %86 = fmul float %81, 0x3FF7154760000000
+  %87 = tail call float @llvm.exp2.f32(float %86)
+  %88 = fmul float %82, 0x3FF7154760000000
+  %89 = tail call float @llvm.exp2.f32(float %88)
+  %90 = fmul float %83, 0x3FF7154760000000
+  %91 = tail call float @llvm.exp2.f32(float %90)
+  tail call spir_func void @_Z7barrierj(i32 1)
+  %92 = fadd float %85, %87
+  %93 = fadd float %89, %92
+  %94 = fadd float %91, %93
+  %95 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32 3, i32 0, float %94)
+  br i1 %30, label %96, label %97
+
+96:                                               ; preds = %78
+  store float %95, ptr addrspace(3) %31, align 4
+  br label %97
+
+97:                                               ; preds = %96, %78
+  tail call spir_func void @_Z7barrierj(i32 1)
+  br i1 %32, label %99, label %.thread
+
+.thread:                                          ; preds = %97
+  %98 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float poison, i32 8)
+  br label %103
+
+99:                                               ; preds = %97
+  %100 = load float, ptr addrspace(3) %34, align 4
+  %101 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float %100, i32 8)
+  br i1 %37, label %102, label %103
+
+102:                                              ; preds = %99
+  store float %101, ptr addrspace(3) %34, align 4
+  br label %103
+
+103:                                              ; preds = %.thread, %102, %99
+  tail call spir_func void @_Z7barrierj(i32 1)
+  %104 = load float, ptr addrspace(3) %6, align 4
+  %105 = fdiv float %87, %104
+  %106 = fdiv float %89, %104
+  %107 = fdiv float %91, %104
+  %108 = mul i32 %39, %3
+  %109 = sext i32 %108 to i64
+  %110 = getelementptr float, ptr addrspace(1) %0, i64 %109
+  %111 = getelementptr float, ptr addrspace(1) %110, i64 %25
+  %112 = getelementptr float, ptr addrspace(1) %110, i64 %26
+  %113 = getelementptr float, ptr addrspace(1) %110, i64 %27
+  br i1 %18, label %114, label %117
+
+114:                                              ; preds = %103
+  %115 = fdiv float %85, %104
+  %116 = getelementptr float, ptr addrspace(1) %110, i64 %28
+  store float %115, ptr addrspace(1) %116, align 4
+  br label %117
+
+117:                                              ; preds = %114, %103
+  br i1 %19, label %118, label %119
+
+118:                                              ; preds = %117
+  store float %105, ptr addrspace(1) %111, align 4
+  br label %119
+
+119:                                              ; preds = %118, %117
+  br i1 %20, label %120, label %121
+
+120:                                              ; preds = %119
+  store float %106, ptr addrspace(1) %112, align 4
+  br label %121
+
+121:                                              ; preds = %120, %119
+  br i1 %21, label %122, label %123
+
+122:                                              ; preds = %121
+  store float %107, ptr addrspace(1) %113, align 4
+  br label %123
+
+123:                                              ; preds = %122, %121
+  %124 = add i32 %39, %11
+  %125 = icmp slt i32 %124...
[truncated]

Copy link
Contributor

@MrSidims MrSidims left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good starter, thanks!
There are several other issues to fix:

  1. Apart of insert/extract instructions we should handle shufflevector instruction, please consider it fixing within this patch;
  2. Apart of insert/extract instructions is legal (though not recommended) to use GEP on vectors. In practice with the switch to opaque pointers it should never happen, but theoretically there can be a GEP(0, 0) to <1 x ... > vector followed by load. In this case backend should remove zero-GEP and replace load with the scalar register.
  3. This is a grey area, but I do believe, that if a function declaration with external linkage has <1 x ...> parameter type, then instead of lowering this type to scalar the backend should emit an error. The reason is the following: lets imagine we have some language builtin which implementation we expect to be linked way after the SPIR-V backend (lets say after SPIR-V was consumed by llvm-spirv emitting LLVM IR). Compiler's library if written on LLVM IR then can expect parameter type to be a vector, while it's not a vector anymore, resulting in an error, which is harder to debug, then an error, that SPIR-V backend would emit.
  4. There are some special intrinsics, like vector reduce, gather/scatter etc. They are not yet supported by the backend so this special case can be handled along with the support for them. Question is: are there any other intrinsics that require special vector handling supported by the backend?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants