Skip to content

Commit

Permalink
Fix bug in occ printer (#554)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglx13 authored Apr 6, 2024
1 parent 6387049 commit f8d2175
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions scripts/amd/occ.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,38 @@
rm -rf ~/.triton/cache/

export MLIR_ENABLE_DUMP=1
export LLVM_IR_ENABLE_DUMP=1
export AMDGCN_ENABLE_DUMP=1
## Assume CDNA arch
SIMD=4
LDS_SIZE=65536
TOTAL_VGPR=512

get_occ_per_CU() {
## $1: vgpr count
vgpr=$1
occPerEU=$((TOTAL_VGPR/vgpr))
if [[ $vgpr -gt 256 ]]; then
occPerEU=1
elif [[ $vgpr -gt 168 ]]; then
occPerEU=2
elif [[ $vgpr -gt 128 ]]; then
occPerEU=3
elif [[ $vgpr -gt 96 ]]; then
occPerEU=4
elif [[ $vgpr -gt 80 ]]; then
occPerEU=5
elif [[ $vgpr -gt 72 ]]; then
occPerEU=6
elif [[ $vgpr -gt 64 ]]; then
occPerEU=7
else
occPerEU=8
fi

occPerCU=$((occPerEU*SIMD/num_warps))
echo $occPerCU
}

$1 > output.mlir 2>&1

LDS_line=$(sed -n '/triton_gpu\.shared\ /p' output.mlir | tail -n 1 | grep -o 'triton_gpu.shared = [0-9]*')
Expand All @@ -26,13 +51,14 @@ SPILLs=$(sed -n '/vgpr_spill/p' output.mlir | tail -n 1 | awk '{print $2}')

echo "VGPRS: $VGPRs (spill: $SPILLs)"

occ_LDS=$((LDS_SIZE/LDS*num_warps/SIMD))
occ_vgpr=$((TOTAL_VGPR/VGPRs))
occ=$occ_vgpr
if [ $occ_LDS -lt $occ_vgpr ];then
occ=$occ_LDS
occLDSPerCU=$((LDS_SIZE/LDS))
occVgprPerCU=$(get_occ_per_CU $VGPRs)
occPerCU=$occVgprPerCU
if [ $occLDSPerCU -lt $occVgprPerCU ];then
occPerCU=$occLDSPerCU
fi
echo "occ: $occ waves/SIMD (occ_LDS: $occ_LDS, occ_vgpr: $occ_vgpr)"
occPerEU=$((occPerCU*num_warps/SIMD))
echo "occupancy: $occPerEU waves/SIMD or $occPerCU workgroups/CU (occLDSPerCU: $occLDSPerCU, occVgprPerCU: $occVgprPerCU)"

perf=$(tail -n 2 output.mlir)
echo "$perf"
Expand Down

0 comments on commit f8d2175

Please sign in to comment.