Skip to content

Commit

Permalink
DAG: Fix assuming f16 is the only 16-bit fp type in concat vector com…
Browse files Browse the repository at this point in the history
…bine (llvm#121637)

This would see if there are mixed integer and FP types and pick an
equivalently sized FP type to use as the vector element type, and only
cast if there were mixed integers. We need to insert a cast if the types
are mixed, which may include different FP types.

Fixes llvm#121601
  • Loading branch information
arsenm authored Jan 6, 2025
1 parent f3590c1 commit d34f7ea
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 17 deletions.
30 changes: 13 additions & 17 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24308,8 +24308,8 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());

// Keep track of what we encounter.
bool AnyInteger = false;
bool AnyFP = false;
EVT AnyFPVT;

for (const SDValue &Op : N->ops()) {
if (ISD::BITCAST == Op.getOpcode() &&
!Op.getOperand(0).getValueType().isVector())
Expand All @@ -24323,27 +24323,23 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
// If it's neither, bail out, it could be something weird like x86mmx.
EVT LastOpVT = Ops.back().getValueType();
if (LastOpVT.isFloatingPoint())
AnyFP = true;
else if (LastOpVT.isInteger())
AnyInteger = true;
else
AnyFPVT = LastOpVT;
else if (!LastOpVT.isInteger())
return SDValue();
}

// If any of the operands is a floating point scalar bitcast to a vector,
// use floating point types throughout, and bitcast everything.
// Replace UNDEFs by another scalar UNDEF node, of the final desired type.
if (AnyFP) {
SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
if (AnyInteger) {
for (SDValue &Op : Ops) {
if (Op.getValueType() == SVT)
continue;
if (Op.isUndef())
Op = DAG.getNode(ISD::UNDEF, DL, SVT);
else
Op = DAG.getBitcast(SVT, Op);
}
if (AnyFPVT != EVT()) {
SVT = AnyFPVT;
for (SDValue &Op : Ops) {
if (Op.getValueType() == SVT)
continue;
if (Op.isUndef())
Op = DAG.getNode(ISD::UNDEF, DL, SVT);
else
Op = DAG.getBitcast(SVT, Op);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx942 < %s | FileCheck %s

define <4 x float> @issue121601(bfloat %fptrunc) {
; CHECK-LABEL: issue121601:
; CHECK: ; %bb.0: ; %bb
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; CHECK-NEXT: v_lshlrev_b32_e32 v0, 16, v0
; CHECK-NEXT: v_mov_b32_e32 v1, v0
; CHECK-NEXT: v_mov_b32_e32 v2, 0
; CHECK-NEXT: v_mov_b32_e32 v3, 0
; CHECK-NEXT: s_setpc_b64 s[30:31]
bb:
%bitcast = bitcast bfloat %fptrunc to <1 x bfloat>
%shufflevector = shufflevector <1 x bfloat> %bitcast, <1 x bfloat> zeroinitializer, <2 x i32> zeroinitializer
%fpext = fpext <2 x bfloat> %shufflevector to <2 x float>
%shufflevector1 = shufflevector <2 x float> %fpext, <2 x float> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x float> %shufflevector1
}

0 comments on commit d34f7ea

Please sign in to comment.